parent
50d0129461
commit
637a6a3628
@ -1,2 +1,4 @@ |
||||
go.sum |
||||
BUILD |
||||
.DS_Store |
||||
tool/kratos/kratos |
||||
|
@ -0,0 +1,18 @@ |
||||
package flagvar |
||||
|
||||
import ( |
||||
"strings" |
||||
) |
||||
|
||||
// StringVars []string implement flag.Value
|
||||
type StringVars []string |
||||
|
||||
func (s StringVars) String() string { |
||||
return strings.Join(s, ",") |
||||
} |
||||
|
||||
// Set implement flag.Value
|
||||
func (s *StringVars) Set(val string) error { |
||||
*s = append(*s, val) |
||||
return nil |
||||
} |
@ -0,0 +1,48 @@ |
||||
package pb |
||||
|
||||
import ( |
||||
"strconv" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
|
||||
any "github.com/golang/protobuf/ptypes/any" |
||||
) |
||||
|
||||
func (e *Error) Error() string { |
||||
return strconv.FormatInt(int64(e.GetErrCode()), 10) |
||||
} |
||||
|
||||
// Code is the code of error.
|
||||
func (e *Error) Code() int { |
||||
return int(e.GetErrCode()) |
||||
} |
||||
|
||||
// Message is error message.
|
||||
func (e *Error) Message() string { |
||||
return e.GetErrMessage() |
||||
} |
||||
|
||||
// Equal compare whether two errors are equal.
|
||||
func (e *Error) Equal(ec error) bool { |
||||
return ecode.Cause(ec).Code() == e.Code() |
||||
} |
||||
|
||||
// Details return error details.
|
||||
func (e *Error) Details() []interface{} { |
||||
return []interface{}{e.GetErrDetail()} |
||||
} |
||||
|
||||
// From will convert ecode.Codes to pb.Error.
|
||||
//
|
||||
// Deprecated: please use ecode.Error
|
||||
func From(ec ecode.Codes) *Error { |
||||
var detail *any.Any |
||||
if details := ec.Details(); len(details) > 0 { |
||||
detail, _ = details[0].(*any.Any) |
||||
} |
||||
return &Error{ |
||||
ErrCode: int32(ec.Code()), |
||||
ErrMessage: ec.Message(), |
||||
ErrDetail: detail, |
||||
} |
||||
} |
@ -0,0 +1,96 @@ |
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: error.proto
|
||||
|
||||
package pb |
||||
|
||||
import proto "github.com/golang/protobuf/proto" |
||||
import fmt "fmt" |
||||
import math "math" |
||||
import any "github.com/golang/protobuf/ptypes/any" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = fmt.Errorf |
||||
var _ = math.Inf |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
// Deprecated: please use ecode.Error
|
||||
type Error struct { |
||||
ErrCode int32 `protobuf:"varint,1,opt,name=err_code,json=errCode,proto3" json:"err_code,omitempty"` |
||||
ErrMessage string `protobuf:"bytes,2,opt,name=err_message,json=errMessage,proto3" json:"err_message,omitempty"` |
||||
ErrDetail *any.Any `protobuf:"bytes,3,opt,name=err_detail,json=errDetail,proto3" json:"err_detail,omitempty"` |
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
XXX_sizecache int32 `json:"-"` |
||||
} |
||||
|
||||
func (m *Error) Reset() { *m = Error{} } |
||||
func (m *Error) String() string { return proto.CompactTextString(m) } |
||||
func (*Error) ProtoMessage() {} |
||||
func (*Error) Descriptor() ([]byte, []int) { |
||||
return fileDescriptor_error_28aad86a4e53115b, []int{0} |
||||
} |
||||
func (m *Error) XXX_Unmarshal(b []byte) error { |
||||
return xxx_messageInfo_Error.Unmarshal(m, b) |
||||
} |
||||
func (m *Error) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { |
||||
return xxx_messageInfo_Error.Marshal(b, m, deterministic) |
||||
} |
||||
func (dst *Error) XXX_Merge(src proto.Message) { |
||||
xxx_messageInfo_Error.Merge(dst, src) |
||||
} |
||||
func (m *Error) XXX_Size() int { |
||||
return xxx_messageInfo_Error.Size(m) |
||||
} |
||||
func (m *Error) XXX_DiscardUnknown() { |
||||
xxx_messageInfo_Error.DiscardUnknown(m) |
||||
} |
||||
|
||||
var xxx_messageInfo_Error proto.InternalMessageInfo |
||||
|
||||
func (m *Error) GetErrCode() int32 { |
||||
if m != nil { |
||||
return m.ErrCode |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *Error) GetErrMessage() string { |
||||
if m != nil { |
||||
return m.ErrMessage |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *Error) GetErrDetail() *any.Any { |
||||
if m != nil { |
||||
return m.ErrDetail |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterType((*Error)(nil), "err.Error") |
||||
} |
||||
|
||||
func init() { proto.RegisterFile("error.proto", fileDescriptor_error_28aad86a4e53115b) } |
||||
|
||||
var fileDescriptor_error_28aad86a4e53115b = []byte{ |
||||
// 164 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x8d, 0xc1, 0xca, 0x82, 0x40, |
||||
0x14, 0x85, 0x99, 0x5f, 0xfc, 0xcb, 0x71, 0x37, 0xb4, 0xd0, 0x36, 0x49, 0x2b, 0x57, 0x23, 0xe4, |
||||
0x13, 0x44, 0xb5, 0x6c, 0xe3, 0x0b, 0x88, 0xe6, 0x49, 0x02, 0xf3, 0xc6, 0xd1, 0x20, 0xdf, 0x3e, |
||||
0x1c, 0x69, 0x79, 0xcf, 0xf7, 0x71, 0x3f, 0x1d, 0x82, 0x14, 0xda, 0x17, 0x65, 0x14, 0xe3, 0x81, |
||||
0xdc, 0xc6, 0xad, 0x48, 0xdb, 0x21, 0x73, 0x53, 0xfd, 0xbe, 0x67, 0x55, 0x3f, 0x2d, 0x7c, 0xff, |
||||
0xd1, 0xfe, 0x65, 0xd6, 0x4d, 0xac, 0xd7, 0x20, 0xcb, 0x9b, 0x34, 0x88, 0x54, 0xa2, 0x52, 0xbf, |
||||
0x58, 0x81, 0x3c, 0x49, 0x03, 0xb3, 0x73, 0x2f, 0xcb, 0x27, 0x86, 0xa1, 0x6a, 0x11, 0xfd, 0x25, |
||||
0x2a, 0x0d, 0x0a, 0x0d, 0xf2, 0xba, 0x2c, 0x26, 0xd7, 0xf3, 0x55, 0x36, 0x18, 0xab, 0x47, 0x17, |
||||
0x79, 0x89, 0x4a, 0xc3, 0xc3, 0xc6, 0x2e, 0x51, 0xfb, 0x8b, 0xda, 0x63, 0x3f, 0x15, 0x01, 0xc8, |
||||
0xb3, 0xd3, 0xea, 0x7f, 0x07, 0xf2, 0x6f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xf7, 0x41, 0x22, 0xfd, |
||||
0xaf, 0x00, 0x00, 0x00, |
||||
} |
@ -0,0 +1,13 @@ |
||||
syntax = "proto3"; |
||||
|
||||
package pb; |
||||
|
||||
import "google/protobuf/any.proto"; |
||||
|
||||
option go_package = "go-common/library/ecode/pb"; |
||||
|
||||
message Error { |
||||
int32 err_code = 1; |
||||
string err_message = 2; |
||||
google.protobuf.Any err_detail = 3; |
||||
} |
@ -0,0 +1,103 @@ |
||||
package ecode |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode/types" |
||||
"github.com/golang/protobuf/proto" |
||||
"github.com/golang/protobuf/ptypes" |
||||
) |
||||
|
||||
// Error new status with code and message
|
||||
func Error(code Code, message string) *Status { |
||||
return &Status{s: &types.Status{Code: int32(code.Code()), Message: message}} |
||||
} |
||||
|
||||
// Errorf new status with code and message
|
||||
func Errorf(code Code, format string, args ...interface{}) *Status { |
||||
return Error(code, fmt.Sprintf(format, args...)) |
||||
} |
||||
|
||||
var _ Codes = &Status{} |
||||
|
||||
// Status statusError is an alias of a status proto
|
||||
// implement ecode.Codes
|
||||
type Status struct { |
||||
s *types.Status |
||||
} |
||||
|
||||
// Error implement error
|
||||
func (s *Status) Error() string { |
||||
return s.Message() |
||||
} |
||||
|
||||
// Code return error code
|
||||
func (s *Status) Code() int { |
||||
return int(s.s.Code) |
||||
} |
||||
|
||||
// Message return error message for developer
|
||||
func (s *Status) Message() string { |
||||
if s.s.Message == "" { |
||||
return strconv.Itoa(int(s.s.Code)) |
||||
} |
||||
return s.s.Message |
||||
} |
||||
|
||||
// Details return error details
|
||||
func (s *Status) Details() []interface{} { |
||||
if s == nil || s.s == nil { |
||||
return nil |
||||
} |
||||
details := make([]interface{}, 0, len(s.s.Details)) |
||||
for _, any := range s.s.Details { |
||||
detail := &ptypes.DynamicAny{} |
||||
if err := ptypes.UnmarshalAny(any, detail); err != nil { |
||||
details = append(details, err) |
||||
continue |
||||
} |
||||
details = append(details, detail.Message) |
||||
} |
||||
return details |
||||
} |
||||
|
||||
// WithDetails WithDetails
|
||||
func (s *Status) WithDetails(pbs ...proto.Message) (*Status, error) { |
||||
for _, pb := range pbs { |
||||
anyMsg, err := ptypes.MarshalAny(pb) |
||||
if err != nil { |
||||
return s, err |
||||
} |
||||
s.s.Details = append(s.s.Details, anyMsg) |
||||
} |
||||
return s, nil |
||||
} |
||||
|
||||
// Equal for compatible.
|
||||
// Deprecated: please use ecode.EqualError.
|
||||
func (s *Status) Equal(err error) bool { |
||||
return EqualError(s, err) |
||||
} |
||||
|
||||
// Proto return origin protobuf message
|
||||
func (s *Status) Proto() *types.Status { |
||||
return s.s |
||||
} |
||||
|
||||
// FromCode create status from ecode
|
||||
func FromCode(code Code) *Status { |
||||
return &Status{s: &types.Status{Code: int32(code)}} |
||||
} |
||||
|
||||
// FromProto new status from grpc detail
|
||||
func FromProto(pbMsg proto.Message) Codes { |
||||
if msg, ok := pbMsg.(*types.Status); ok { |
||||
if msg.Message == "" { |
||||
// NOTE: if message is empty convert to pure Code, will get message from config center.
|
||||
return Code(msg.Code) |
||||
} |
||||
return &Status{s: msg} |
||||
} |
||||
return Errorf(ServerErr, "invalid proto message get %v", pbMsg) |
||||
} |
@ -0,0 +1,66 @@ |
||||
package ecode |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/golang/protobuf/ptypes/timestamp" |
||||
"github.com/smartystreets/goconvey/convey" |
||||
"github.com/stretchr/testify/assert" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode/types" |
||||
) |
||||
|
||||
func TestEqual(t *testing.T) { |
||||
convey.Convey("Equal", t, func(ctx convey.C) { |
||||
ctx.Convey("When err1=Error(RequestErr, 'test') and err2=Errorf(RequestErr, 'test')", func(ctx convey.C) { |
||||
err1 := Error(RequestErr, "test") |
||||
err2 := Errorf(RequestErr, "test") |
||||
ctx.Convey("Then err1=err2, err1 != nil", func(ctx convey.C) { |
||||
ctx.So(err1, convey.ShouldResemble, err2) |
||||
ctx.So(err1, convey.ShouldNotBeNil) |
||||
}) |
||||
}) |
||||
}) |
||||
// assert.True(t, OK.Equal(nil))
|
||||
// assert.True(t, err1.Equal(err2))
|
||||
// assert.False(t, err1.Equal(nil))
|
||||
// assert.True(t, Equal(nil, nil))
|
||||
} |
||||
|
||||
func TestDetail(t *testing.T) { |
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
st, _ := Error(RequestErr, "RequestErr").WithDetails(m) |
||||
|
||||
assert.Equal(t, "RequestErr", st.Message()) |
||||
assert.Equal(t, int(RequestErr), st.Code()) |
||||
assert.IsType(t, m, st.Details()[0]) |
||||
} |
||||
|
||||
func TestFromCode(t *testing.T) { |
||||
err := FromCode(RequestErr) |
||||
|
||||
assert.Equal(t, int(RequestErr), err.Code()) |
||||
assert.Equal(t, "-400", err.Message()) |
||||
} |
||||
|
||||
func TestFromProto(t *testing.T) { |
||||
msg := &types.Status{Code: 2233, Message: "error"} |
||||
err := FromProto(msg) |
||||
|
||||
assert.Equal(t, 2233, err.Code()) |
||||
assert.Equal(t, "error", err.Message()) |
||||
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
err = FromProto(m) |
||||
assert.Equal(t, -500, err.Code()) |
||||
assert.Contains(t, err.Message(), "invalid proto message get") |
||||
} |
||||
|
||||
func TestEmpty(t *testing.T) { |
||||
st := &Status{} |
||||
assert.Len(t, st.Details(), 0) |
||||
|
||||
st = nil |
||||
assert.Len(t, st.Details(), 0) |
||||
} |
@ -0,0 +1,102 @@ |
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: internal/types/status.proto
|
||||
|
||||
package types // import "github.com/bilibili/kratos/pkg/ecode/types"
|
||||
|
||||
import proto "github.com/golang/protobuf/proto" |
||||
import fmt "fmt" |
||||
import math "math" |
||||
import any "github.com/golang/protobuf/ptypes/any" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = fmt.Errorf |
||||
var _ = math.Inf |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type Status struct { |
||||
// The error code see ecode.Code
|
||||
Code int32 `protobuf:"varint,1,opt,name=code" json:"code,omitempty"` |
||||
// A developer-facing error message, which should be in English. Any
|
||||
Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` |
||||
// A list of messages that carry the error details. There is a common set of
|
||||
// message types for APIs to use.
|
||||
Details []*any.Any `protobuf:"bytes,3,rep,name=details" json:"details,omitempty"` |
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
XXX_sizecache int32 `json:"-"` |
||||
} |
||||
|
||||
func (m *Status) Reset() { *m = Status{} } |
||||
func (m *Status) String() string { return proto.CompactTextString(m) } |
||||
func (*Status) ProtoMessage() {} |
||||
func (*Status) Descriptor() ([]byte, []int) { |
||||
return fileDescriptor_status_88668d6b2bf80f08, []int{0} |
||||
} |
||||
func (m *Status) XXX_Unmarshal(b []byte) error { |
||||
return xxx_messageInfo_Status.Unmarshal(m, b) |
||||
} |
||||
func (m *Status) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { |
||||
return xxx_messageInfo_Status.Marshal(b, m, deterministic) |
||||
} |
||||
func (dst *Status) XXX_Merge(src proto.Message) { |
||||
xxx_messageInfo_Status.Merge(dst, src) |
||||
} |
||||
func (m *Status) XXX_Size() int { |
||||
return xxx_messageInfo_Status.Size(m) |
||||
} |
||||
func (m *Status) XXX_DiscardUnknown() { |
||||
xxx_messageInfo_Status.DiscardUnknown(m) |
||||
} |
||||
|
||||
var xxx_messageInfo_Status proto.InternalMessageInfo |
||||
|
||||
func (m *Status) GetCode() int32 { |
||||
if m != nil { |
||||
return m.Code |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *Status) GetMessage() string { |
||||
if m != nil { |
||||
return m.Message |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *Status) GetDetails() []*any.Any { |
||||
if m != nil { |
||||
return m.Details |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterType((*Status)(nil), "bilibili.rpc.Status") |
||||
} |
||||
|
||||
func init() { proto.RegisterFile("internal/types/status.proto", fileDescriptor_status_88668d6b2bf80f08) } |
||||
|
||||
var fileDescriptor_status_88668d6b2bf80f08 = []byte{ |
||||
// 220 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x8f, 0xb1, 0x4a, 0x04, 0x31, |
||||
0x10, 0x86, 0xd9, 0x5b, 0xbd, 0xc3, 0x9c, 0x85, 0x04, 0x8b, 0x55, 0x9b, 0xc5, 0x6a, 0x0b, 0x4d, |
||||
0x40, 0x4b, 0x2b, 0xcf, 0x17, 0x58, 0x22, 0x36, 0x76, 0x49, 0x6e, 0x2e, 0x04, 0x92, 0xcc, 0x92, |
||||
0xe4, 0x8a, 0xbc, 0x8e, 0x4f, 0x2a, 0x9b, 0x65, 0x41, 0x8b, 0x19, 0x66, 0x98, 0xff, 0xe7, 0xfb, |
||||
0x87, 0x3c, 0xd8, 0x90, 0x21, 0x06, 0xe9, 0x78, 0x2e, 0x13, 0x24, 0x9e, 0xb2, 0xcc, 0xe7, 0xc4, |
||||
0xa6, 0x88, 0x19, 0xe9, 0xb5, 0xb2, 0xce, 0xce, 0xc5, 0xe2, 0xa4, 0xef, 0xef, 0x0c, 0xa2, 0x71, |
||||
0xc0, 0xeb, 0x4d, 0x9d, 0x4f, 0x5c, 0x86, 0xb2, 0x08, 0x1f, 0x4f, 0x64, 0xfb, 0x59, 0x8d, 0x94, |
||||
0x92, 0x0b, 0x8d, 0x47, 0xe8, 0x9a, 0xbe, 0x19, 0x2e, 0x45, 0x9d, 0x69, 0x47, 0x76, 0x1e, 0x52, |
||||
0x92, 0x06, 0xba, 0x4d, 0xdf, 0x0c, 0x57, 0x62, 0x5d, 0x29, 0x23, 0xbb, 0x23, 0x64, 0x69, 0x5d, |
||||
0xea, 0xda, 0xbe, 0x1d, 0xf6, 0x2f, 0xb7, 0x6c, 0x81, 0xb0, 0x15, 0xc2, 0xde, 0x43, 0x11, 0xab, |
||||
0xe8, 0xf0, 0x45, 0x6e, 0x34, 0x7a, 0xf6, 0x37, 0xd6, 0x61, 0xbf, 0x90, 0xc7, 0xd9, 0x30, 0x36, |
||||
0xdf, 0x4f, 0x06, 0x9f, 0x35, 0x7a, 0x8f, 0x81, 0x3b, 0xab, 0xa2, 0x8c, 0x85, 0xc3, 0x9c, 0x82, |
||||
0xff, 0x7f, 0xf4, 0xad, 0xf6, 0x9f, 0x4d, 0x2b, 0xc6, 0x0f, 0xb5, 0xad, 0xb4, 0xd7, 0xdf, 0x00, |
||||
0x00, 0x00, 0xff, 0xff, 0x80, 0xa3, 0xc1, 0x82, 0x0d, 0x01, 0x00, 0x00, |
||||
} |
@ -0,0 +1,23 @@ |
||||
syntax = "proto3"; |
||||
|
||||
package bilibili.rpc; |
||||
|
||||
import "google/protobuf/any.proto"; |
||||
|
||||
option go_package = "github.com/bilibili/Kratos/pkg/ecode/types;types"; |
||||
option java_multiple_files = true; |
||||
option java_outer_classname = "StatusProto"; |
||||
option java_package = "com.bilibili.rpc"; |
||||
option objc_class_prefix = "RPC"; |
||||
|
||||
message Status { |
||||
// The error code see ecode.Code |
||||
int32 code = 1; |
||||
|
||||
// A developer-facing error message, which should be in English. Any |
||||
string message = 2; |
||||
|
||||
// A list of messages that carry the error details. There is a common set of |
||||
// message types for APIs to use. |
||||
repeated google.protobuf.Any details = 3; |
||||
} |
@ -0,0 +1,62 @@ |
||||
### net/rpc/warden |
||||
##### Version 1.1.12 |
||||
1. 设置 caller 为 no_user 如果 user 不存在 |
||||
|
||||
##### Version 1.1.12 |
||||
1. warden支持mirror传递 |
||||
|
||||
##### Version 1.1.11 |
||||
1. Validate RequestErr支持详细报错信息 |
||||
|
||||
##### Version 1.1.10 |
||||
1. 默认读取环境中的color |
||||
|
||||
##### Version 1.1.9 |
||||
1. 增加NonBlock模式 |
||||
|
||||
##### Version 1.1.8 |
||||
1. 新增appid mock |
||||
|
||||
##### Version 1.1.7 |
||||
1. 兼容cpu为0和wrr dt为0的情况 |
||||
|
||||
##### Version 1.1.6 |
||||
1. 修改caller传递和获取方式 |
||||
2. 添加error detail example |
||||
|
||||
##### Version 1.1.5 |
||||
1. 增加server端json格式支持 |
||||
|
||||
##### Version 1.1.4 |
||||
1. 判断reosvler.builder为nil之后再注册 |
||||
|
||||
##### Version 1.1.3 |
||||
1. 支持zone和clusters |
||||
|
||||
##### Version 1.1.2 |
||||
1. 业务错误日志记为 WARN |
||||
|
||||
##### Version 1.1.1 |
||||
1. server实现了返回cpu信息 |
||||
|
||||
##### Version 1.1.0 |
||||
1. 增加ErrorDetail |
||||
2. 修复日志打印error信息丢失问题 |
||||
|
||||
##### Version 1.0.3 |
||||
1. 给server增加keepalive参数 |
||||
|
||||
##### Version 1.0.2 |
||||
|
||||
1. 替代默认的timoue,使用durtaion.Shrink()来传递context |
||||
2. 修复peer.Addr为nil时会panic的问题 |
||||
|
||||
##### Version 1.0.1 |
||||
|
||||
1. 去除timeout的手动传递,改为使用grpc默认自带的grpc-timeout |
||||
2. 获取server address改为使用call option的方式,去除对balancer的依赖 |
||||
|
||||
##### Version 1.0.0 |
||||
|
||||
1. 使用NewClient来新建一个RPC客户端,并默认集成trace、log、recovery、moniter拦截器 |
||||
2. 使用NewServer来新建一个RPC服务端,并默认集成trace、log、recovery、moniter拦截器 |
@ -0,0 +1,10 @@ |
||||
# See the OWNERS docs at https://go.k8s.io/owners |
||||
|
||||
approvers: |
||||
- caoguoliang |
||||
- maojian |
||||
labels: |
||||
- library |
||||
reviewers: |
||||
- caoguoliang |
||||
- maojian |
@ -0,0 +1,13 @@ |
||||
#### net/rcp/warden |
||||
|
||||
##### 项目简介 |
||||
|
||||
来自 bilibili 主站技术部的 RPC 框架,融合主站技术部的核心科技,带来如飞一般的体验。 |
||||
|
||||
##### 编译环境 |
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行** |
||||
|
||||
##### 依赖包 |
||||
|
||||
- [grpc](google.golang.org/grpc) |
@ -0,0 +1,20 @@ |
||||
### business/warden/balancer/p2c |
||||
|
||||
### Version 1.3.1 |
||||
1. add more test |
||||
|
||||
### Version 1.3 |
||||
1. P2C替换smooth weighted round-robin |
||||
|
||||
##### Version 1.2.1 |
||||
1. 删除了netflix ribbon的权重算法,改成了平方根算法 |
||||
|
||||
##### Version 1.2.0 |
||||
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据) |
||||
|
||||
##### Version 1.1.0 |
||||
1. 实现了动态计算的调度轮询算法 |
||||
|
||||
##### Version 1.0.0 |
||||
|
||||
1. 实现了带权重可以识别Color的轮询算法 |
@ -0,0 +1,9 @@ |
||||
# See the OWNERS docs at https://go.k8s.io/owners |
||||
|
||||
approvers: |
||||
- caoguoliang |
||||
labels: |
||||
- library |
||||
reviewers: |
||||
- caoguoliang |
||||
- maojian |
@ -0,0 +1,13 @@ |
||||
#### business/warden/balancer/wrr |
||||
|
||||
##### 项目简介 |
||||
|
||||
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用 |
||||
|
||||
##### 编译环境 |
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行** |
||||
|
||||
##### 依赖包 |
||||
|
||||
- [grpc](google.golang.org/grpc) |
@ -0,0 +1,269 @@ |
||||
package p2c |
||||
|
||||
import ( |
||||
"context" |
||||
"math" |
||||
"math/rand" |
||||
"strconv" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
wmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
|
||||
"google.golang.org/grpc/balancer" |
||||
"google.golang.org/grpc/balancer/base" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/resolver" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
const ( |
||||
// The mean lifetime of `cost`, it reaches its half-life after Tau*ln(2).
|
||||
tau = int64(time.Millisecond * 600) |
||||
// if statistic not collected,we add a big penalty to endpoint
|
||||
penalty = uint64(1000 * time.Millisecond * 250) |
||||
|
||||
forceGap = int64(time.Second * 3) |
||||
) |
||||
|
||||
var _ base.PickerBuilder = &p2cPickerBuilder{} |
||||
var _ balancer.Picker = &p2cPicker{} |
||||
|
||||
// Name is the name of pick of two random choices balancer.
|
||||
const Name = "p2c" |
||||
|
||||
// newBuilder creates a new weighted-roundrobin balancer builder.
|
||||
func newBuilder() balancer.Builder { |
||||
return base.NewBalancerBuilder(Name, &p2cPickerBuilder{}) |
||||
} |
||||
|
||||
func init() { |
||||
balancer.Register(newBuilder()) |
||||
} |
||||
|
||||
type subConn struct { |
||||
// metadata
|
||||
conn balancer.SubConn |
||||
addr resolver.Address |
||||
meta wmd.MD |
||||
|
||||
//client statistic data
|
||||
lag uint64 |
||||
success uint64 |
||||
inflight int64 |
||||
// server statistic data
|
||||
svrCPU uint64 |
||||
|
||||
//last collected timestamp
|
||||
stamp int64 |
||||
//last pick timestamp
|
||||
pick int64 |
||||
// request number in a period time
|
||||
reqs int64 |
||||
} |
||||
|
||||
func (sc *subConn) health() uint64 { |
||||
return atomic.LoadUint64(&sc.success) |
||||
} |
||||
|
||||
func (sc *subConn) cost() uint64 { |
||||
load := atomic.LoadUint64(&sc.svrCPU) * atomic.LoadUint64(&sc.lag) * uint64(atomic.LoadInt64(&sc.inflight)) |
||||
if load == 0 { |
||||
// penalty是初始化没有数据时的惩罚值,默认为1e9 * 250
|
||||
load = penalty |
||||
} |
||||
return load |
||||
} |
||||
|
||||
// statistics is info for log
|
||||
type statistic struct { |
||||
addr string |
||||
score float64 |
||||
cs uint64 |
||||
lantency uint64 |
||||
cpu uint64 |
||||
inflight int64 |
||||
reqs int64 |
||||
} |
||||
|
||||
type p2cPickerBuilder struct{} |
||||
|
||||
func (*p2cPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker { |
||||
p := &p2cPicker{ |
||||
colors: make(map[string]*p2cPicker), |
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())), |
||||
} |
||||
for addr, sc := range readySCs { |
||||
meta, ok := addr.Metadata.(wmd.MD) |
||||
if !ok { |
||||
meta = wmd.MD{ |
||||
Weight: 10, |
||||
} |
||||
} |
||||
subc := &subConn{ |
||||
conn: sc, |
||||
addr: addr, |
||||
meta: meta, |
||||
|
||||
svrCPU: 500, |
||||
lag: 0, |
||||
success: 1000, |
||||
inflight: 1, |
||||
} |
||||
if meta.Color == "" { |
||||
p.subConns = append(p.subConns, subc) |
||||
continue |
||||
} |
||||
// if color not empty, use color picker
|
||||
cp, ok := p.colors[meta.Color] |
||||
if !ok { |
||||
cp = &p2cPicker{r: rand.New(rand.NewSource(time.Now().UnixNano()))} |
||||
p.colors[meta.Color] = cp |
||||
} |
||||
cp.subConns = append(cp.subConns, subc) |
||||
} |
||||
return p |
||||
} |
||||
|
||||
type p2cPicker struct { |
||||
// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
|
||||
// created. The slice is immutable. Each Get() will do a round robin
|
||||
// selection from it and return the selected SubConn.
|
||||
subConns []*subConn |
||||
colors map[string]*p2cPicker |
||||
logTs int64 |
||||
r *rand.Rand |
||||
lk sync.Mutex |
||||
} |
||||
|
||||
func (p *p2cPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { |
||||
// FIXME refactor to unify the color logic
|
||||
color := nmd.String(ctx, nmd.Color) |
||||
if color == "" && env.Color != "" { |
||||
color = env.Color |
||||
} |
||||
if color != "" { |
||||
if cp, ok := p.colors[color]; ok { |
||||
return cp.pick(ctx, opts) |
||||
} |
||||
} |
||||
return p.pick(ctx, opts) |
||||
} |
||||
|
||||
func (p *p2cPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { |
||||
var pc, upc *subConn |
||||
start := time.Now().UnixNano() |
||||
|
||||
if len(p.subConns) <= 0 { |
||||
return nil, nil, balancer.ErrNoSubConnAvailable |
||||
} else if len(p.subConns) == 1 { |
||||
pc = p.subConns[0] |
||||
} else { |
||||
// choose two distinct nodes
|
||||
p.lk.Lock() |
||||
a := p.r.Intn(len(p.subConns)) |
||||
b := p.r.Intn(len(p.subConns) - 1) |
||||
p.lk.Unlock() |
||||
if b >= a { |
||||
b = b + 1 |
||||
} |
||||
nodeA, nodeB := p.subConns[a], p.subConns[b] |
||||
// meta.Weight为服务发布者在disocvery中设置的权重
|
||||
if nodeA.cost()*nodeB.health()*nodeB.meta.Weight > nodeB.cost()*nodeA.health()*nodeA.meta.Weight { |
||||
pc, upc = nodeB, nodeA |
||||
} else { |
||||
pc, upc = nodeA, nodeB |
||||
} |
||||
// 如果选中的节点,在forceGap期间内没有被选中一次,那么强制一次
|
||||
// 利用强制的机会,来触发成功率、延迟的衰减
|
||||
// 原子锁conn.pick保证并发安全,放行一次
|
||||
pick := atomic.LoadInt64(&upc.pick) |
||||
if start-pick > forceGap && atomic.CompareAndSwapInt64(&upc.pick, pick, start) { |
||||
pc = upc |
||||
} |
||||
} |
||||
|
||||
// 节点未发生切换才更新pick时间
|
||||
if pc != upc { |
||||
atomic.StoreInt64(&pc.pick, start) |
||||
} |
||||
atomic.AddInt64(&pc.inflight, 1) |
||||
atomic.AddInt64(&pc.reqs, 1) |
||||
return pc.conn, func(di balancer.DoneInfo) { |
||||
atomic.AddInt64(&pc.inflight, -1) |
||||
now := time.Now().UnixNano() |
||||
// get moving average ratio w
|
||||
stamp := atomic.SwapInt64(&pc.stamp, now) |
||||
td := now - stamp |
||||
if td < 0 { |
||||
td = 0 |
||||
} |
||||
w := math.Exp(float64(-td) / float64(tau)) |
||||
|
||||
lag := now - start |
||||
if lag < 0 { |
||||
lag = 0 |
||||
} |
||||
oldLag := atomic.LoadUint64(&pc.lag) |
||||
if oldLag == 0 { |
||||
w = 0.0 |
||||
} |
||||
lag = int64(float64(oldLag)*w + float64(lag)*(1.0-w)) |
||||
atomic.StoreUint64(&pc.lag, uint64(lag)) |
||||
|
||||
success := uint64(1000) // error value ,if error set 1
|
||||
if di.Err != nil { |
||||
if st, ok := status.FromError(di.Err); ok { |
||||
// only counter the local grpc error, ignore any business error
|
||||
if st.Code() != codes.Unknown && st.Code() != codes.OK { |
||||
success = 0 |
||||
} |
||||
} |
||||
} |
||||
oldSuc := atomic.LoadUint64(&pc.success) |
||||
success = uint64(float64(oldSuc)*w + float64(success)*(1.0-w)) |
||||
atomic.StoreUint64(&pc.success, success) |
||||
|
||||
trailer := di.Trailer |
||||
if strs, ok := trailer[wmd.CPUUsage]; ok { |
||||
if cpu, err2 := strconv.ParseUint(strs[0], 10, 64); err2 == nil && cpu > 0 { |
||||
atomic.StoreUint64(&pc.svrCPU, cpu) |
||||
} |
||||
} |
||||
|
||||
logTs := atomic.LoadInt64(&p.logTs) |
||||
if now-logTs > int64(time.Second*3) { |
||||
if atomic.CompareAndSwapInt64(&p.logTs, logTs, now) { |
||||
p.printStats() |
||||
} |
||||
} |
||||
}, nil |
||||
} |
||||
|
||||
func (p *p2cPicker) printStats() { |
||||
if len(p.subConns) <= 0 { |
||||
return |
||||
} |
||||
stats := make([]statistic, 0, len(p.subConns)) |
||||
for _, conn := range p.subConns { |
||||
var stat statistic |
||||
stat.addr = conn.addr.Addr |
||||
stat.cpu = atomic.LoadUint64(&conn.svrCPU) |
||||
stat.cs = atomic.LoadUint64(&conn.success) |
||||
stat.inflight = atomic.LoadInt64(&conn.inflight) |
||||
stat.lantency = atomic.LoadUint64(&conn.lag) |
||||
stat.reqs = atomic.SwapInt64(&conn.reqs, 0) |
||||
load := stat.cpu * uint64(stat.inflight) * stat.lantency |
||||
if load != 0 { |
||||
stat.score = float64(stat.cs*conn.meta.Weight*1e8) / float64(load) |
||||
} |
||||
stats = append(stats, stat) |
||||
} |
||||
log.Info("p2c %s : %+v", p.subConns[0].addr.ServerName, stats) |
||||
//fmt.Printf("%+v\n", stats)
|
||||
} |
@ -0,0 +1,347 @@ |
||||
package p2c |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
"math/rand" |
||||
"strconv" |
||||
"sync/atomic" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
|
||||
"google.golang.org/grpc/balancer" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/metadata" |
||||
"google.golang.org/grpc/resolver" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
var serverNum int |
||||
var cliNum int |
||||
var concurrency int |
||||
var extraLoad int64 |
||||
var extraDelay int64 |
||||
var extraWeight uint64 |
||||
|
||||
func init() { |
||||
flag.IntVar(&serverNum, "snum", 5, "-snum 6") |
||||
flag.IntVar(&cliNum, "cnum", 5, "-cnum 12") |
||||
flag.IntVar(&concurrency, "concurrency", 5, "-cc 10") |
||||
flag.Int64Var(&extraLoad, "exload", 3, "-exload 3") |
||||
flag.Int64Var(&extraDelay, "exdelay", 0, "-exdelay 250") |
||||
flag.Uint64Var(&extraWeight, "extraWeight", 0, "-exdelay 50") |
||||
} |
||||
|
||||
type testSubConn struct { |
||||
addr resolver.Address |
||||
wait chan struct{} |
||||
//statics
|
||||
reqs int64 |
||||
usage int64 |
||||
cpu int64 |
||||
prevReq int64 |
||||
prevUsage int64 |
||||
//control params
|
||||
loadJitter int64 |
||||
delayJitter int64 |
||||
} |
||||
|
||||
func newTestSubConn(addr string, weight uint64, color string) (sc *testSubConn) { |
||||
sc = &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: addr, |
||||
Metadata: wmeta.MD{ |
||||
Weight: weight, |
||||
Color: color, |
||||
}, |
||||
}, |
||||
wait: make(chan struct{}, 1000), |
||||
} |
||||
go func() { |
||||
for { |
||||
for i := 0; i < 210; i++ { |
||||
<-sc.wait |
||||
} |
||||
time.Sleep(time.Millisecond * 20) |
||||
} |
||||
}() |
||||
|
||||
return |
||||
} |
||||
|
||||
func (s *testSubConn) connect(ctx context.Context) { |
||||
time.Sleep(time.Millisecond * 15) |
||||
//add qps counter when request come in
|
||||
atomic.AddInt64(&s.reqs, 1) |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
case s.wait <- struct{}{}: |
||||
atomic.AddInt64(&s.usage, 1) |
||||
} |
||||
load := atomic.LoadInt64(&s.loadJitter) |
||||
if load > 0 { |
||||
for i := 0; i <= rand.Intn(int(load)); i++ { |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
case s.wait <- struct{}{}: |
||||
atomic.AddInt64(&s.usage, 1) |
||||
} |
||||
} |
||||
} |
||||
delay := atomic.LoadInt64(&s.delayJitter) |
||||
if delay > 0 { |
||||
delay = rand.Int63n(delay) |
||||
time.Sleep(time.Millisecond * time.Duration(delay)) |
||||
} |
||||
} |
||||
|
||||
func (s *testSubConn) UpdateAddresses([]resolver.Address) { |
||||
|
||||
} |
||||
|
||||
// Connect starts the connecting for this SubConn.
|
||||
func (s *testSubConn) Connect() { |
||||
|
||||
} |
||||
|
||||
func TestBalancerPick(t *testing.T) { |
||||
scs := map[resolver.Address]balancer.SubConn{} |
||||
sc1 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test1", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 8, |
||||
}, |
||||
}, |
||||
} |
||||
sc2 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test2", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 4, |
||||
Color: "red", |
||||
}, |
||||
}, |
||||
} |
||||
sc3 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test3", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 2, |
||||
Color: "red", |
||||
}, |
||||
}, |
||||
} |
||||
sc4 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test4", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 2, |
||||
Color: "purple", |
||||
}, |
||||
}, |
||||
} |
||||
scs[sc1.addr] = sc1 |
||||
scs[sc2.addr] = sc2 |
||||
scs[sc3.addr] = sc3 |
||||
scs[sc4.addr] = sc4 |
||||
b := &p2cPickerBuilder{} |
||||
picker := b.Build(scs) |
||||
res := []string{"test1", "test1", "test1", "test1"} |
||||
for i := 0; i < 3; i++ { |
||||
conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res[i] { |
||||
t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i]) |
||||
} |
||||
} |
||||
|
||||
ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"})) |
||||
for i := 0; i < 4; i++ { |
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res[i] { |
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i]) |
||||
} |
||||
} |
||||
|
||||
env.Color = "purple" |
||||
ctx2 := context.Background() |
||||
for i := 0; i < 4; i++ { |
||||
conn, _, err := picker.Pick(ctx2, balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != "test4" { |
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i]) |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
func Benchmark_Wrr(b *testing.B) { |
||||
scs := map[resolver.Address]balancer.SubConn{} |
||||
for i := 0; i < 50; i++ { |
||||
addr := resolver.Address{ |
||||
Addr: fmt.Sprintf("addr_%d", i), |
||||
Metadata: wmeta.MD{Weight: 10}, |
||||
} |
||||
scs[addr] = &testSubConn{addr: addr} |
||||
} |
||||
wpb := &p2cPickerBuilder{} |
||||
picker := wpb.Build(scs) |
||||
opt := balancer.PickOptions{} |
||||
ctx := context.Background() |
||||
for idx := 0; idx < b.N; idx++ { |
||||
_, done, err := picker.Pick(ctx, opt) |
||||
if err != nil { |
||||
done(balancer.DoneInfo{}) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestChaosPick(t *testing.T) { |
||||
flag.Parse() |
||||
fmt.Printf("start chaos test!svrNum:%d cliNum:%d concurrency:%d exLoad:%d exDelay:%d\n", serverNum, cliNum, concurrency, extraLoad, extraDelay) |
||||
c := newController(serverNum, cliNum) |
||||
c.launch(concurrency) |
||||
go c.updateStatics() |
||||
go c.control(extraLoad, extraDelay) |
||||
time.Sleep(time.Second * 50) |
||||
} |
||||
|
||||
func newController(svrNum int, cliNum int) *controller { |
||||
//new servers
|
||||
servers := []*testSubConn{} |
||||
var weight uint64 = 10 |
||||
if extraWeight > 0 { |
||||
weight = extraWeight |
||||
} |
||||
for i := 0; i < svrNum; i++ { |
||||
weight += extraWeight |
||||
sc := newTestSubConn(fmt.Sprintf("addr_%d", i), weight, "") |
||||
servers = append(servers, sc) |
||||
} |
||||
//new clients
|
||||
var clients []balancer.Picker |
||||
scs := map[resolver.Address]balancer.SubConn{} |
||||
for _, v := range servers { |
||||
scs[v.addr] = v |
||||
} |
||||
for i := 0; i < cliNum; i++ { |
||||
wpb := &p2cPickerBuilder{} |
||||
picker := wpb.Build(scs) |
||||
clients = append(clients, picker) |
||||
} |
||||
|
||||
c := &controller{ |
||||
servers: servers, |
||||
clients: clients, |
||||
} |
||||
return c |
||||
} |
||||
|
||||
type controller struct { |
||||
servers []*testSubConn |
||||
clients []balancer.Picker |
||||
} |
||||
|
||||
func (c *controller) launch(concurrency int) { |
||||
opt := balancer.PickOptions{} |
||||
bkg := context.Background() |
||||
for i := range c.clients { |
||||
for j := 0; j < concurrency; j++ { |
||||
picker := c.clients[i] |
||||
go func() { |
||||
for { |
||||
ctx, cancel := context.WithTimeout(bkg, time.Millisecond*250) |
||||
sc, done, _ := picker.Pick(ctx, opt) |
||||
server := sc.(*testSubConn) |
||||
server.connect(ctx) |
||||
var err error |
||||
if ctx.Err() != nil { |
||||
err = status.Errorf(codes.DeadlineExceeded, "dead") |
||||
} |
||||
cancel() |
||||
cpu := atomic.LoadInt64(&server.cpu) |
||||
md := make(map[string]string) |
||||
md[wmeta.CPUUsage] = strconv.FormatInt(cpu, 10) |
||||
done(balancer.DoneInfo{Trailer: metadata.New(md), Err: err}) |
||||
time.Sleep(time.Millisecond * 10) |
||||
} |
||||
}() |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (c *controller) updateStatics() { |
||||
for { |
||||
time.Sleep(time.Millisecond * 500) |
||||
for _, sc := range c.servers { |
||||
usage := atomic.LoadInt64(&sc.usage) |
||||
avgCpu := (usage - sc.prevUsage) * 2 |
||||
atomic.StoreInt64(&sc.cpu, avgCpu) |
||||
sc.prevUsage = usage |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
func (c *controller) control(extraLoad, extraDelay int64) { |
||||
var chaos int |
||||
for { |
||||
fmt.Printf("\n") |
||||
//make some chaos
|
||||
n := rand.Intn(3) |
||||
chaos = n + 1 |
||||
for i := 0; i < chaos; i++ { |
||||
if extraLoad > 0 { |
||||
degree := rand.Int63n(extraLoad) |
||||
degree++ |
||||
atomic.StoreInt64(&c.servers[i].loadJitter, degree) |
||||
fmt.Printf("set addr_%d load:%d ", i, degree) |
||||
} |
||||
if extraDelay > 0 { |
||||
degree := rand.Int63n(extraDelay) |
||||
atomic.StoreInt64(&c.servers[i].delayJitter, degree) |
||||
fmt.Printf("set addr_%d delay:%dms ", i, degree) |
||||
} |
||||
} |
||||
fmt.Printf("\n") |
||||
sleep := int64(5) |
||||
time.Sleep(time.Second * time.Duration(sleep)) |
||||
for _, sc := range c.servers { |
||||
req := atomic.LoadInt64(&sc.reqs) |
||||
qps := (req - sc.prevReq) / sleep |
||||
wait := len(sc.wait) |
||||
sc.prevReq = req |
||||
fmt.Printf("%s qps:%d waits:%d\n", sc.addr.Addr, qps, wait) |
||||
} |
||||
for _, picker := range c.clients { |
||||
p := picker.(*p2cPicker) |
||||
p.printStats() |
||||
} |
||||
fmt.Printf("\n") |
||||
//reset chaos
|
||||
for i := 0; i < chaos; i++ { |
||||
atomic.StoreInt64(&c.servers[i].loadJitter, 0) |
||||
atomic.StoreInt64(&c.servers[i].delayJitter, 0) |
||||
} |
||||
chaos = 0 |
||||
} |
||||
} |
@ -0,0 +1,17 @@ |
||||
### business/warden/balancer/wrr |
||||
|
||||
##### Version 1.3.0 |
||||
1. 迁移 stat.Summary 到 metric.RollingCounter,metric.RollingGauge |
||||
|
||||
##### Version 1.2.1 |
||||
1. 删除了netflix ribbon的权重算法,改成了平方根算法 |
||||
|
||||
##### Version 1.2.0 |
||||
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据) |
||||
|
||||
##### Version 1.1.0 |
||||
1. 实现了动态计算的调度轮询算法 |
||||
|
||||
##### Version 1.0.0 |
||||
|
||||
1. 实现了带权重可以识别Color的轮询算法 |
@ -0,0 +1,9 @@ |
||||
# See the OWNERS docs at https://go.k8s.io/owners |
||||
|
||||
approvers: |
||||
- caoguoliang |
||||
labels: |
||||
- library |
||||
reviewers: |
||||
- caoguoliang |
||||
- maojian |
@ -0,0 +1,13 @@ |
||||
#### business/warden/balancer/wrr |
||||
|
||||
##### 项目简介 |
||||
|
||||
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用 |
||||
|
||||
##### 编译环境 |
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行** |
||||
|
||||
##### 依赖包 |
||||
|
||||
- [grpc](google.golang.org/grpc) |
@ -0,0 +1,302 @@ |
||||
package wrr |
||||
|
||||
import ( |
||||
"context" |
||||
"math" |
||||
"strconv" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
"github.com/bilibili/kratos/pkg/stat/metric" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/balancer" |
||||
"google.golang.org/grpc/balancer/base" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/metadata" |
||||
"google.golang.org/grpc/resolver" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
var _ base.PickerBuilder = &wrrPickerBuilder{} |
||||
var _ balancer.Picker = &wrrPicker{} |
||||
|
||||
// var dwrrFeature feature.Feature = "dwrr"
|
||||
|
||||
// Name is the name of round_robin balancer.
|
||||
const Name = "wrr" |
||||
|
||||
// newBuilder creates a new weighted-roundrobin balancer builder.
|
||||
func newBuilder() balancer.Builder { |
||||
return base.NewBalancerBuilder(Name, &wrrPickerBuilder{}) |
||||
} |
||||
|
||||
func init() { |
||||
//feature.DefaultGate.Add(map[feature.Feature]feature.Spec{
|
||||
// dwrrFeature: {Default: false},
|
||||
//})
|
||||
|
||||
balancer.Register(newBuilder()) |
||||
} |
||||
|
||||
type serverInfo struct { |
||||
cpu int64 |
||||
success uint64 // float64 bits
|
||||
} |
||||
|
||||
type subConn struct { |
||||
conn balancer.SubConn |
||||
addr resolver.Address |
||||
meta wmeta.MD |
||||
|
||||
err metric.RollingCounter |
||||
latency metric.RollingGauge |
||||
si serverInfo |
||||
// effective weight
|
||||
ewt int64 |
||||
// current weight
|
||||
cwt int64 |
||||
// last score
|
||||
score float64 |
||||
} |
||||
|
||||
func (c *subConn) errSummary() (err int64, req int64) { |
||||
c.err.Reduce(func(iterator metric.Iterator) float64 { |
||||
for iterator.Next() { |
||||
bucket := iterator.Bucket() |
||||
req += bucket.Count |
||||
for _, p := range bucket.Points { |
||||
err += int64(p) |
||||
} |
||||
} |
||||
return 0 |
||||
}) |
||||
return |
||||
} |
||||
|
||||
func (c *subConn) latencySummary() (latency float64, count int64) { |
||||
c.latency.Reduce(func(iterator metric.Iterator) float64 { |
||||
for iterator.Next() { |
||||
bucket := iterator.Bucket() |
||||
count += bucket.Count |
||||
for _, p := range bucket.Points { |
||||
latency += p |
||||
} |
||||
} |
||||
return 0 |
||||
}) |
||||
return latency / float64(count), count |
||||
} |
||||
|
||||
// statistics is info for log
|
||||
type statistics struct { |
||||
addr string |
||||
ewt int64 |
||||
cs float64 |
||||
ss float64 |
||||
lantency float64 |
||||
cpu float64 |
||||
req int64 |
||||
} |
||||
|
||||
// Stats is grpc Interceptor for client to collect server stats
|
||||
func Stats() grpc.UnaryClientInterceptor { |
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { |
||||
var ( |
||||
trailer metadata.MD |
||||
md nmd.MD |
||||
ok bool |
||||
) |
||||
if md, ok = nmd.FromContext(ctx); !ok { |
||||
md = nmd.MD{} |
||||
} else { |
||||
md = md.Copy() |
||||
} |
||||
ctx = nmd.NewContext(ctx, md) |
||||
opts = append(opts, grpc.Trailer(&trailer)) |
||||
|
||||
err = invoker(ctx, method, req, reply, cc, opts...) |
||||
|
||||
conn, ok := md["conn"].(*subConn) |
||||
if !ok { |
||||
return |
||||
} |
||||
if strs, ok := trailer[wmeta.CPUUsage]; ok { |
||||
if cpu, err2 := strconv.ParseInt(strs[0], 10, 64); err2 == nil && cpu > 0 { |
||||
atomic.StoreInt64(&conn.si.cpu, cpu) |
||||
} |
||||
} |
||||
return |
||||
} |
||||
} |
||||
|
||||
type wrrPickerBuilder struct{} |
||||
|
||||
func (*wrrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker { |
||||
p := &wrrPicker{ |
||||
colors: make(map[string]*wrrPicker), |
||||
} |
||||
for addr, sc := range readySCs { |
||||
meta, ok := addr.Metadata.(wmeta.MD) |
||||
if !ok { |
||||
meta = wmeta.MD{ |
||||
Weight: 10, |
||||
} |
||||
} |
||||
subc := &subConn{ |
||||
conn: sc, |
||||
addr: addr, |
||||
|
||||
meta: meta, |
||||
ewt: int64(meta.Weight), |
||||
score: -1, |
||||
|
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
|
||||
si: serverInfo{cpu: 500, success: math.Float64bits(1)}, |
||||
} |
||||
if meta.Color == "" { |
||||
p.subConns = append(p.subConns, subc) |
||||
continue |
||||
} |
||||
// if color not empty, use color picker
|
||||
cp, ok := p.colors[meta.Color] |
||||
if !ok { |
||||
cp = &wrrPicker{} |
||||
p.colors[meta.Color] = cp |
||||
} |
||||
cp.subConns = append(cp.subConns, subc) |
||||
} |
||||
return p |
||||
} |
||||
|
||||
type wrrPicker struct { |
||||
// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
|
||||
// created. The slice is immutable. Each Get() will do a round robin
|
||||
// selection from it and return the selected SubConn.
|
||||
subConns []*subConn |
||||
colors map[string]*wrrPicker |
||||
updateAt int64 |
||||
|
||||
mu sync.Mutex |
||||
} |
||||
|
||||
func (p *wrrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { |
||||
// FIXME refactor to unify the color logic
|
||||
color := nmd.String(ctx, nmd.Color) |
||||
if color == "" && env.Color != "" { |
||||
color = env.Color |
||||
} |
||||
if color != "" { |
||||
if cp, ok := p.colors[color]; ok { |
||||
return cp.pick(ctx, opts) |
||||
} |
||||
} |
||||
return p.pick(ctx, opts) |
||||
} |
||||
|
||||
func (p *wrrPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { |
||||
var ( |
||||
conn *subConn |
||||
totalWeight int64 |
||||
) |
||||
if len(p.subConns) <= 0 { |
||||
return nil, nil, balancer.ErrNoSubConnAvailable |
||||
} |
||||
p.mu.Lock() |
||||
// nginx wrr load balancing algorithm: http://blog.csdn.net/zhangskd/article/details/50194069
|
||||
for _, sc := range p.subConns { |
||||
totalWeight += sc.ewt |
||||
sc.cwt += sc.ewt |
||||
if conn == nil || conn.cwt < sc.cwt { |
||||
conn = sc |
||||
} |
||||
} |
||||
conn.cwt -= totalWeight |
||||
p.mu.Unlock() |
||||
start := time.Now() |
||||
if cmd, ok := nmd.FromContext(ctx); ok { |
||||
cmd["conn"] = conn |
||||
} |
||||
//if !feature.DefaultGate.Enabled(dwrrFeature) {
|
||||
// return conn.conn, nil, nil
|
||||
//}
|
||||
return conn.conn, func(di balancer.DoneInfo) { |
||||
ev := int64(0) // error value ,if error set 1
|
||||
if di.Err != nil { |
||||
if st, ok := status.FromError(di.Err); ok { |
||||
// only counter the local grpc error, ignore any business error
|
||||
if st.Code() != codes.Unknown && st.Code() != codes.OK { |
||||
ev = 1 |
||||
} |
||||
} |
||||
} |
||||
conn.err.Add(ev) |
||||
|
||||
now := time.Now() |
||||
conn.latency.Add(now.Sub(start).Nanoseconds() / 1e5) |
||||
u := atomic.LoadInt64(&p.updateAt) |
||||
if now.UnixNano()-u < int64(time.Second) { |
||||
return |
||||
} |
||||
if !atomic.CompareAndSwapInt64(&p.updateAt, u, now.UnixNano()) { |
||||
return |
||||
} |
||||
var ( |
||||
stats = make([]statistics, len(p.subConns)) |
||||
count int |
||||
total float64 |
||||
) |
||||
for i, conn := range p.subConns { |
||||
cpu := float64(atomic.LoadInt64(&conn.si.cpu)) |
||||
ss := math.Float64frombits(atomic.LoadUint64(&conn.si.success)) |
||||
errc, req := conn.errSummary() |
||||
lagv, lagc := conn.latencySummary() |
||||
|
||||
if req > 0 && lagc > 0 && lagv > 0 { |
||||
// client-side success ratio
|
||||
cs := 1 - (float64(errc) / float64(req)) |
||||
if cs <= 0 { |
||||
cs = 0.1 |
||||
} else if cs <= 0.2 && req <= 5 { |
||||
cs = 0.2 |
||||
} |
||||
conn.score = math.Sqrt((cs * ss * ss * 1e9) / (lagv * cpu)) |
||||
stats[i] = statistics{cs: cs, ss: ss, lantency: lagv, cpu: cpu, req: req} |
||||
} |
||||
stats[i].addr = conn.addr.Addr |
||||
|
||||
if conn.score > 0 { |
||||
total += conn.score |
||||
count++ |
||||
} |
||||
} |
||||
// count must be greater than 1,otherwise will lead ewt to 0
|
||||
if count < 2 { |
||||
return |
||||
} |
||||
avgscore := total / float64(count) |
||||
p.mu.Lock() |
||||
for i, conn := range p.subConns { |
||||
if conn.score <= 0 { |
||||
conn.score = avgscore |
||||
} |
||||
conn.ewt = int64(conn.score * float64(conn.meta.Weight)) |
||||
stats[i].ewt = conn.ewt |
||||
} |
||||
p.mu.Unlock() |
||||
log.Info("warden wrr(%s): %+v", conn.addr.ServerName, stats) |
||||
}, nil |
||||
|
||||
} |
@ -0,0 +1,189 @@ |
||||
package wrr |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
"github.com/bilibili/kratos/pkg/stat/metric" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"google.golang.org/grpc/balancer" |
||||
"google.golang.org/grpc/resolver" |
||||
) |
||||
|
||||
type testSubConn struct { |
||||
addr resolver.Address |
||||
} |
||||
|
||||
func (s *testSubConn) UpdateAddresses([]resolver.Address) { |
||||
|
||||
} |
||||
|
||||
// Connect starts the connecting for this SubConn.
|
||||
func (s *testSubConn) Connect() { |
||||
fmt.Println(s.addr.Addr) |
||||
} |
||||
|
||||
func TestBalancerPick(t *testing.T) { |
||||
scs := map[resolver.Address]balancer.SubConn{} |
||||
sc1 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test1", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 8, |
||||
}, |
||||
}, |
||||
} |
||||
sc2 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test2", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 4, |
||||
Color: "red", |
||||
}, |
||||
}, |
||||
} |
||||
sc3 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test3", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 2, |
||||
Color: "red", |
||||
}, |
||||
}, |
||||
} |
||||
scs[sc1.addr] = sc1 |
||||
scs[sc2.addr] = sc2 |
||||
scs[sc3.addr] = sc3 |
||||
b := &wrrPickerBuilder{} |
||||
picker := b.Build(scs) |
||||
res := []string{"test1", "test1", "test1", "test1"} |
||||
for i := 0; i < 3; i++ { |
||||
conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res[i] { |
||||
t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i]) |
||||
} |
||||
} |
||||
res2 := []string{"test2", "test3", "test2", "test2", "test3", "test2"} |
||||
ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "red"})) |
||||
for i := 0; i < 6; i++ { |
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res2[i] { |
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i]) |
||||
} |
||||
} |
||||
ctx = nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"})) |
||||
for i := 0; i < 4; i++ { |
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res[i] { |
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i]) |
||||
} |
||||
} |
||||
|
||||
// test for env color
|
||||
ctx = context.Background() |
||||
env.Color = "red" |
||||
for i := 0; i < 6; i++ { |
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{}) |
||||
if err != nil { |
||||
t.Fatalf("picker.Pick failed!idx:=%d", i) |
||||
} |
||||
sc := conn.(*testSubConn) |
||||
if sc.addr.Addr != res2[i] { |
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestBalancerDone(t *testing.T) { |
||||
scs := map[resolver.Address]balancer.SubConn{} |
||||
sc1 := &testSubConn{ |
||||
addr: resolver.Address{ |
||||
Addr: "test1", |
||||
Metadata: wmeta.MD{ |
||||
Weight: 8, |
||||
}, |
||||
}, |
||||
} |
||||
scs[sc1.addr] = sc1 |
||||
b := &wrrPickerBuilder{} |
||||
picker := b.Build(scs) |
||||
|
||||
_, done, _ := picker.Pick(context.Background(), balancer.PickOptions{}) |
||||
time.Sleep(100 * time.Millisecond) |
||||
done(balancer.DoneInfo{Err: status.Errorf(codes.Unknown, "test")}) |
||||
err, req := picker.(*wrrPicker).subConns[0].errSummary() |
||||
assert.Equal(t, int64(0), err) |
||||
assert.Equal(t, int64(1), req) |
||||
|
||||
latency, count := picker.(*wrrPicker).subConns[0].latencySummary() |
||||
expectLatency := float64(100*time.Millisecond) / 1e5 |
||||
if !(expectLatency < latency && latency < (expectLatency+100)) { |
||||
t.Fatalf("latency is less than 100ms or greter than 100ms, %f", latency) |
||||
} |
||||
assert.Equal(t, int64(1), count) |
||||
|
||||
_, done, _ = picker.Pick(context.Background(), balancer.PickOptions{}) |
||||
done(balancer.DoneInfo{Err: status.Errorf(codes.Aborted, "test")}) |
||||
err, req = picker.(*wrrPicker).subConns[0].errSummary() |
||||
assert.Equal(t, int64(1), err) |
||||
assert.Equal(t, int64(2), req) |
||||
} |
||||
|
||||
func TestErrSummary(t *testing.T) { |
||||
sc := &subConn{ |
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
} |
||||
for i := 0; i < 10; i++ { |
||||
sc.err.Add(0) |
||||
sc.err.Add(1) |
||||
} |
||||
err, req := sc.errSummary() |
||||
assert.Equal(t, int64(10), err) |
||||
assert.Equal(t, int64(20), req) |
||||
} |
||||
|
||||
func TestLatencySummary(t *testing.T) { |
||||
sc := &subConn{ |
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{ |
||||
Size: 10, |
||||
BucketDuration: time.Millisecond * 100, |
||||
}), |
||||
} |
||||
for i := 1; i <= 100; i++ { |
||||
sc.latency.Add(int64(i)) |
||||
} |
||||
latency, count := sc.latencySummary() |
||||
assert.Equal(t, 50.50, latency) |
||||
assert.Equal(t, int64(100), count) |
||||
} |
@ -0,0 +1,334 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/url" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/conf/flagvar" |
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
"github.com/bilibili/kratos/pkg/naming/discovery" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/balancer/p2c" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/pkg/errors" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/credentials" |
||||
"google.golang.org/grpc/metadata" |
||||
"google.golang.org/grpc/peer" |
||||
gstatus "google.golang.org/grpc/status" |
||||
) |
||||
|
||||
var _grpcTarget flagvar.StringVars |
||||
|
||||
var ( |
||||
_once sync.Once |
||||
_defaultCliConf = &ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Millisecond * 250), |
||||
Subset: 50, |
||||
} |
||||
_defaultClient *Client |
||||
) |
||||
|
||||
func baseMetadata() metadata.MD { |
||||
gmd := metadata.MD{nmd.Caller: []string{env.AppID}} |
||||
if env.Color != "" { |
||||
gmd[nmd.Color] = []string{env.Color} |
||||
} |
||||
return gmd |
||||
} |
||||
|
||||
// ClientConfig is rpc client conf.
|
||||
type ClientConfig struct { |
||||
Dial xtime.Duration |
||||
Timeout xtime.Duration |
||||
Breaker *breaker.Config |
||||
Method map[string]*ClientConfig |
||||
Clusters []string |
||||
Zone string |
||||
Subset int |
||||
NonBlock bool |
||||
} |
||||
|
||||
// Client is the framework's client side instance, it contains the ctx, opt and interceptors.
|
||||
// Create an instance of Client, by using NewClient().
|
||||
type Client struct { |
||||
conf *ClientConfig |
||||
breaker *breaker.Group |
||||
mutex sync.RWMutex |
||||
|
||||
opt []grpc.DialOption |
||||
handlers []grpc.UnaryClientInterceptor |
||||
} |
||||
|
||||
// handle returns a new unary client interceptor for OpenTracing\Logging\LinkTimeout.
|
||||
func (c *Client) handle() grpc.UnaryClientInterceptor { |
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { |
||||
var ( |
||||
ok bool |
||||
cmd nmd.MD |
||||
t trace.Trace |
||||
gmd metadata.MD |
||||
conf *ClientConfig |
||||
cancel context.CancelFunc |
||||
addr string |
||||
p peer.Peer |
||||
) |
||||
var ec ecode.Codes = ecode.OK |
||||
// apm tracing
|
||||
if t, ok = trace.FromContext(ctx); ok { |
||||
t = t.Fork("", method) |
||||
defer t.Finish(&err) |
||||
} |
||||
|
||||
// setup metadata
|
||||
gmd = baseMetadata() |
||||
trace.Inject(t, trace.GRPCFormat, gmd) |
||||
c.mutex.RLock() |
||||
if conf, ok = c.conf.Method[method]; !ok { |
||||
conf = c.conf |
||||
} |
||||
c.mutex.RUnlock() |
||||
brk := c.breaker.Get(method) |
||||
if err = brk.Allow(); err != nil { |
||||
statsClient.Incr(method, "breaker") |
||||
return |
||||
} |
||||
defer onBreaker(brk, &err) |
||||
_, ctx, cancel = conf.Timeout.Shrink(ctx) |
||||
defer cancel() |
||||
if cmd, ok = nmd.FromContext(ctx); ok { |
||||
for netKey, val := range cmd { |
||||
if !nmd.IsOutgoingKey(netKey) { |
||||
continue |
||||
} |
||||
valstr, ok := val.(string) |
||||
if ok { |
||||
gmd[netKey] = []string{valstr} |
||||
} |
||||
} |
||||
} |
||||
// merge with old matadata if exists
|
||||
if oldmd, ok := metadata.FromOutgoingContext(ctx); ok { |
||||
gmd = metadata.Join(gmd, oldmd) |
||||
} |
||||
ctx = metadata.NewOutgoingContext(ctx, gmd) |
||||
|
||||
opts = append(opts, grpc.Peer(&p)) |
||||
if err = invoker(ctx, method, req, reply, cc, opts...); err != nil { |
||||
gst, _ := gstatus.FromError(err) |
||||
ec = status.ToEcode(gst) |
||||
err = errors.WithMessage(ec, gst.Message()) |
||||
} |
||||
if p.Addr != nil { |
||||
addr = p.Addr.String() |
||||
} |
||||
if t != nil { |
||||
t.SetTag(trace.String(trace.TagAddress, addr), trace.String(trace.TagComment, "")) |
||||
} |
||||
return |
||||
} |
||||
} |
||||
|
||||
func onBreaker(breaker breaker.Breaker, err *error) { |
||||
if err != nil && *err != nil { |
||||
if ecode.ServerErr.Equal(*err) || ecode.ServiceUnavailable.Equal(*err) || ecode.Deadline.Equal(*err) || ecode.LimitExceed.Equal(*err) { |
||||
breaker.MarkFailed() |
||||
return |
||||
} |
||||
} |
||||
breaker.MarkSuccess() |
||||
} |
||||
|
||||
// NewConn will create a grpc conn by default config.
|
||||
func NewConn(target string, opt ...grpc.DialOption) (*grpc.ClientConn, error) { |
||||
return DefaultClient().Dial(context.Background(), target, opt...) |
||||
} |
||||
|
||||
// NewClient returns a new blank Client instance with a default client interceptor.
|
||||
// opt can be used to add grpc dial options.
|
||||
func NewClient(conf *ClientConfig, opt ...grpc.DialOption) *Client { |
||||
resolver.Register(discovery.Builder()) |
||||
c := new(Client) |
||||
if err := c.SetConfig(conf); err != nil { |
||||
panic(err) |
||||
} |
||||
c.UseOpt(grpc.WithBalancerName(p2c.Name)) |
||||
c.UseOpt(opt...) |
||||
c.Use(c.recovery(), clientLogging(), c.handle()) |
||||
return c |
||||
} |
||||
|
||||
// DefaultClient returns a new default Client instance with a default client interceptor and default dialoption.
|
||||
// opt can be used to add grpc dial options.
|
||||
func DefaultClient() *Client { |
||||
resolver.Register(discovery.Builder()) |
||||
_once.Do(func() { |
||||
_defaultClient = NewClient(nil) |
||||
}) |
||||
return _defaultClient |
||||
} |
||||
|
||||
// SetConfig hot reloads client config
|
||||
func (c *Client) SetConfig(conf *ClientConfig) (err error) { |
||||
if conf == nil { |
||||
conf = _defaultCliConf |
||||
} |
||||
if conf.Dial <= 0 { |
||||
conf.Dial = xtime.Duration(time.Second * 10) |
||||
} |
||||
if conf.Timeout <= 0 { |
||||
conf.Timeout = xtime.Duration(time.Millisecond * 250) |
||||
} |
||||
if conf.Subset <= 0 { |
||||
conf.Subset = 50 |
||||
} |
||||
|
||||
// FIXME(maojian) check Method dial/timeout
|
||||
c.mutex.Lock() |
||||
c.conf = conf |
||||
if c.breaker == nil { |
||||
c.breaker = breaker.NewGroup(conf.Breaker) |
||||
} else { |
||||
c.breaker.Reload(conf.Breaker) |
||||
} |
||||
c.mutex.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
// Use attachs a global inteceptor to the Client.
|
||||
// For example, this is the right place for a circuit breaker or error management inteceptor.
|
||||
func (c *Client) Use(handlers ...grpc.UnaryClientInterceptor) *Client { |
||||
finalSize := len(c.handlers) + len(handlers) |
||||
if finalSize >= int(_abortIndex) { |
||||
panic("warden: client use too many handlers") |
||||
} |
||||
mergedHandlers := make([]grpc.UnaryClientInterceptor, finalSize) |
||||
copy(mergedHandlers, c.handlers) |
||||
copy(mergedHandlers[len(c.handlers):], handlers) |
||||
c.handlers = mergedHandlers |
||||
return c |
||||
} |
||||
|
||||
// UseOpt attachs a global grpc DialOption to the Client.
|
||||
func (c *Client) UseOpt(opt ...grpc.DialOption) *Client { |
||||
c.opt = append(c.opt, opt...) |
||||
return c |
||||
} |
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
// Target format is scheme://authority/endpoint?query_arg=value
|
||||
// example: discovery://default/account.account.service?cluster=shfy01&cluster=shfy02
|
||||
func (c *Client) Dial(ctx context.Context, target string, opt ...grpc.DialOption) (conn *grpc.ClientConn, err error) { |
||||
if !c.conf.NonBlock { |
||||
c.opt = append(c.opt, grpc.WithBlock()) |
||||
} |
||||
c.opt = append(c.opt, grpc.WithInsecure()) |
||||
c.opt = append(c.opt, grpc.WithUnaryInterceptor(c.chainUnaryClient())) |
||||
c.opt = append(c.opt, opt...) |
||||
c.mutex.RLock() |
||||
conf := c.conf |
||||
c.mutex.RUnlock() |
||||
if conf.Dial > 0 { |
||||
var cancel context.CancelFunc |
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial)) |
||||
defer cancel() |
||||
} |
||||
if u, e := url.Parse(target); e == nil { |
||||
v := u.Query() |
||||
for _, c := range c.conf.Clusters { |
||||
v.Add(naming.MetaCluster, c) |
||||
} |
||||
if c.conf.Zone != "" { |
||||
v.Add(naming.MetaZone, c.conf.Zone) |
||||
} |
||||
if v.Get("subset") == "" && c.conf.Subset > 0 { |
||||
v.Add("subset", strconv.FormatInt(int64(c.conf.Subset), 10)) |
||||
} |
||||
u.RawQuery = v.Encode() |
||||
// 比较_grpcTarget中的appid是否等于u.path中的appid,并替换成mock的地址
|
||||
for _, t := range _grpcTarget { |
||||
strs := strings.SplitN(t, "=", 2) |
||||
if len(strs) == 2 && ("/"+strs[0]) == u.Path { |
||||
u.Path = "/" + strs[1] |
||||
u.Scheme = "passthrough" |
||||
u.RawQuery = "" |
||||
break |
||||
} |
||||
} |
||||
target = u.String() |
||||
} |
||||
if conn, err = grpc.DialContext(ctx, target, c.opt...); err != nil { |
||||
fmt.Fprintf(os.Stderr, "warden client: dial %s error %v!", target, err) |
||||
} |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
// DialTLS creates a client connection over tls transport to the given target.
|
||||
func (c *Client) DialTLS(ctx context.Context, target string, file string, name string) (conn *grpc.ClientConn, err error) { |
||||
var creds credentials.TransportCredentials |
||||
creds, err = credentials.NewClientTLSFromFile(file, name) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
c.opt = append(c.opt, grpc.WithBlock()) |
||||
c.opt = append(c.opt, grpc.WithTransportCredentials(creds)) |
||||
c.opt = append(c.opt, grpc.WithUnaryInterceptor(c.chainUnaryClient())) |
||||
c.mutex.RLock() |
||||
conf := c.conf |
||||
c.mutex.RUnlock() |
||||
if conf.Dial > 0 { |
||||
var cancel context.CancelFunc |
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial)) |
||||
defer cancel() |
||||
} |
||||
conn, err = grpc.DialContext(ctx, target, c.opt...) |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
// chainUnaryClient creates a single interceptor out of a chain of many interceptors.
|
||||
//
|
||||
// Execution is done in left-to-right order, including passing of context.
|
||||
// For example ChainUnaryClient(one, two, three) will execute one before two before three.
|
||||
func (c *Client) chainUnaryClient() grpc.UnaryClientInterceptor { |
||||
n := len(c.handlers) |
||||
if n == 0 { |
||||
return func(ctx context.Context, method string, req, reply interface{}, |
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { |
||||
return invoker(ctx, method, req, reply, cc, opts...) |
||||
} |
||||
} |
||||
|
||||
return func(ctx context.Context, method string, req, reply interface{}, |
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { |
||||
var ( |
||||
i int |
||||
chainHandler grpc.UnaryInvoker |
||||
) |
||||
chainHandler = func(ictx context.Context, imethod string, ireq, ireply interface{}, ic *grpc.ClientConn, iopts ...grpc.CallOption) error { |
||||
if i == n-1 { |
||||
return invoker(ictx, imethod, ireq, ireply, ic, iopts...) |
||||
} |
||||
i++ |
||||
return c.handlers[i](ictx, imethod, ireq, ireply, ic, chainHandler, iopts...) |
||||
} |
||||
|
||||
return c.handlers[0](ctx, method, req, reply, cc, chainHandler, opts...) |
||||
} |
||||
} |
@ -0,0 +1,91 @@ |
||||
package warden_test |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"google.golang.org/grpc" |
||||
) |
||||
|
||||
type helloServer struct { |
||||
} |
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil |
||||
} |
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
for i := 0; i < 3; i++ { |
||||
in, err := ss.Recv() |
||||
if err == io.EOF { |
||||
return nil |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true} |
||||
err = ss.Send(ret) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
|
||||
} |
||||
|
||||
func ExampleServer() { |
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second), Addr: ":8080"}) |
||||
// apply server interceptor middleware
|
||||
s.Use(func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
newctx, cancel := context.WithTimeout(ctx, time.Second*10) |
||||
defer cancel() |
||||
resp, err := handler(newctx, req) |
||||
return resp, err |
||||
}) |
||||
pb.RegisterGreeterServer(s.Server(), &helloServer{}) |
||||
s.Start() |
||||
} |
||||
|
||||
func ExampleClient() { |
||||
client := warden.NewClient(&warden.ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
}) |
||||
// apply client interceptor middleware
|
||||
client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) { |
||||
newctx, cancel := context.WithTimeout(ctx, time.Second*5) |
||||
defer cancel() |
||||
ret = invoker(newctx, method, req, reply, cc, opts...) |
||||
return |
||||
}) |
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080") |
||||
if err != nil { |
||||
log.Error("did not connect: %v", err) |
||||
return |
||||
} |
||||
defer conn.Close() |
||||
|
||||
c := pb.NewGreeterClient(conn) |
||||
name := "2233" |
||||
rp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name, Age: 18}) |
||||
if err != nil { |
||||
log.Error("could not greet: %v", err) |
||||
return |
||||
} |
||||
fmt.Println("rp", *rp) |
||||
} |
@ -0,0 +1,189 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"flag" |
||||
"log" |
||||
"reflect" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
goproto "github.com/gogo/protobuf/proto" |
||||
"github.com/montanaflynn/stats" |
||||
"golang.org/x/net/context" |
||||
"google.golang.org/grpc" |
||||
) |
||||
|
||||
const ( |
||||
iws = 65535 * 1000 |
||||
iwsc = 65535 * 10000 |
||||
readBuffer = 32 * 1024 |
||||
writeBuffer = 32 * 1024 |
||||
) |
||||
|
||||
var concurrency = flag.Int("c", 50, "concurrency") |
||||
var total = flag.Int("t", 500000, "total requests for all clients") |
||||
var host = flag.String("s", "127.0.0.1:8972", "server ip and port") |
||||
var isWarden = flag.Bool("w", true, "is warden or grpc client") |
||||
var strLen = flag.Int("l", 600, "the length of the str") |
||||
|
||||
func wardenCli() proto.HelloClient { |
||||
log.Println("start warden cli") |
||||
client := warden.NewClient(&warden.ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
}, |
||||
grpc.WithInitialWindowSize(iws), |
||||
grpc.WithInitialConnWindowSize(iwsc), |
||||
grpc.WithReadBufferSize(readBuffer), |
||||
grpc.WithWriteBufferSize(writeBuffer)) |
||||
conn, err := client.Dial(context.Background(), *host) |
||||
if err != nil { |
||||
log.Fatalf("did not connect: %v", err) |
||||
} |
||||
cli := proto.NewHelloClient(conn) |
||||
return cli |
||||
} |
||||
|
||||
func grpcCli() proto.HelloClient { |
||||
log.Println("start grpc cli") |
||||
conn, err := grpc.Dial(*host, grpc.WithInsecure(), |
||||
grpc.WithInitialWindowSize(iws), |
||||
grpc.WithInitialConnWindowSize(iwsc), |
||||
grpc.WithReadBufferSize(readBuffer), |
||||
grpc.WithWriteBufferSize(writeBuffer)) |
||||
if err != nil { |
||||
log.Fatalf("did not connect: %v", err) |
||||
} |
||||
cli := proto.NewHelloClient(conn) |
||||
return cli |
||||
} |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
c := *concurrency |
||||
m := *total / c |
||||
var wg sync.WaitGroup |
||||
wg.Add(c) |
||||
log.Printf("concurrency: %d\nrequests per client: %d\n\n", c, m) |
||||
|
||||
args := prepareArgs() |
||||
b, _ := goproto.Marshal(args) |
||||
log.Printf("message size: %d bytes\n\n", len(b)) |
||||
|
||||
var trans uint64 |
||||
var transOK uint64 |
||||
d := make([][]int64, c) |
||||
for i := 0; i < c; i++ { |
||||
dt := make([]int64, 0, m) |
||||
d = append(d, dt) |
||||
} |
||||
var cli proto.HelloClient |
||||
if *isWarden { |
||||
cli = wardenCli() |
||||
} else { |
||||
cli = grpcCli() |
||||
} |
||||
//warmup
|
||||
cli.Say(context.Background(), args) |
||||
|
||||
totalT := time.Now().UnixNano() |
||||
for i := 0; i < c; i++ { |
||||
go func(i int) { |
||||
for j := 0; j < m; j++ { |
||||
t := time.Now().UnixNano() |
||||
reply, err := cli.Say(context.Background(), args) |
||||
t = time.Now().UnixNano() - t |
||||
d[i] = append(d[i], t) |
||||
if err == nil && reply.Field1 == "OK" { |
||||
atomic.AddUint64(&transOK, 1) |
||||
} |
||||
atomic.AddUint64(&trans, 1) |
||||
} |
||||
wg.Done() |
||||
}(i) |
||||
|
||||
} |
||||
wg.Wait() |
||||
|
||||
totalT = time.Now().UnixNano() - totalT |
||||
totalT = totalT / 1e6 |
||||
log.Printf("took %d ms for %d requests\n", totalT, *total) |
||||
totalD := make([]int64, 0, *total) |
||||
for _, k := range d { |
||||
totalD = append(totalD, k...) |
||||
} |
||||
totalD2 := make([]float64, 0, *total) |
||||
for _, k := range totalD { |
||||
totalD2 = append(totalD2, float64(k)) |
||||
} |
||||
|
||||
mean, _ := stats.Mean(totalD2) |
||||
median, _ := stats.Median(totalD2) |
||||
max, _ := stats.Max(totalD2) |
||||
min, _ := stats.Min(totalD2) |
||||
tp99, _ := stats.Percentile(totalD2, 99) |
||||
tp999, _ := stats.Percentile(totalD2, 99.9) |
||||
|
||||
log.Printf("sent requests : %d\n", *total) |
||||
log.Printf("received requests_OK : %d\n", atomic.LoadUint64(&transOK)) |
||||
log.Printf("throughput (TPS) : %d\n", int64(c*m)*1000/totalT) |
||||
log.Printf("mean: %v ms, median: %v ms, max: %v ms, min: %v ms, p99: %v ms, p999:%v ms\n", mean/1e6, median/1e6, max/1e6, min/1e6, tp99/1e6, tp999/1e6) |
||||
|
||||
} |
||||
|
||||
func prepareArgs() *proto.BenchmarkMessage { |
||||
b := true |
||||
var i int32 = 120000 |
||||
var i64 int64 = 98765432101234 |
||||
var s = "许多往事在眼前一幕一幕,变的那麼模糊" |
||||
repeat := *strLen / (8 * 54) |
||||
if repeat == 0 { |
||||
repeat = 1 |
||||
} |
||||
var str string |
||||
for i := 0; i < repeat; i++ { |
||||
str += s |
||||
} |
||||
var args proto.BenchmarkMessage |
||||
|
||||
v := reflect.ValueOf(&args).Elem() |
||||
num := v.NumField() |
||||
for k := 0; k < num; k++ { |
||||
field := v.Field(k) |
||||
if field.Type().Kind() == reflect.Ptr { |
||||
switch v.Field(k).Type().Elem().Kind() { |
||||
case reflect.Int, reflect.Int32: |
||||
field.Set(reflect.ValueOf(&i)) |
||||
case reflect.Int64: |
||||
field.Set(reflect.ValueOf(&i64)) |
||||
case reflect.Bool: |
||||
field.Set(reflect.ValueOf(&b)) |
||||
case reflect.String: |
||||
field.Set(reflect.ValueOf(&str)) |
||||
} |
||||
} else { |
||||
switch field.Kind() { |
||||
case reflect.Int, reflect.Int32, reflect.Int64: |
||||
field.SetInt(9876543) |
||||
case reflect.Bool: |
||||
field.SetBool(true) |
||||
case reflect.String: |
||||
field.SetString(str) |
||||
} |
||||
} |
||||
} |
||||
return &args |
||||
} |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,60 @@ |
||||
syntax = "proto3"; |
||||
package proto; |
||||
|
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||
|
||||
option optimize_for = SPEED; |
||||
option (gogoproto.goproto_enum_prefix_all) = false; |
||||
option (gogoproto.goproto_getters_all) = false; |
||||
option (gogoproto.unmarshaler_all) = true; |
||||
option (gogoproto.marshaler_all) = true; |
||||
option (gogoproto.sizer_all) = true; |
||||
|
||||
service Hello { |
||||
// Sends a greeting |
||||
rpc Say (BenchmarkMessage) returns (BenchmarkMessage) {} |
||||
} |
||||
|
||||
|
||||
message BenchmarkMessage { |
||||
string field1 = 1; |
||||
string field9 = 9; |
||||
string field18 = 18; |
||||
bool field80 = 80; |
||||
bool field81 = 81; |
||||
int32 field2 = 2; |
||||
int32 field3 = 3; |
||||
int32 field280 = 280; |
||||
int32 field6 = 6; |
||||
int64 field22 = 22; |
||||
string field4 = 4; |
||||
fixed64 field5 = 5; |
||||
bool field59 = 59; |
||||
string field7 = 7; |
||||
int32 field16 = 16; |
||||
int32 field130 = 130; |
||||
bool field12 = 12; |
||||
bool field17 = 17; |
||||
bool field13 = 13; |
||||
bool field14 = 14; |
||||
int32 field104 = 104; |
||||
int32 field100 = 100; |
||||
int32 field101 = 101; |
||||
string field102 = 102; |
||||
string field103 = 103; |
||||
int32 field29 = 29; |
||||
bool field30 = 30; |
||||
int32 field60 = 60; |
||||
int32 field271 = 271; |
||||
int32 field272 = 272; |
||||
int32 field150 = 150; |
||||
int32 field23 = 23; |
||||
bool field24 = 24 ; |
||||
int32 field25 = 25 ; |
||||
bool field78 = 78; |
||||
int32 field67 = 67; |
||||
int32 field68 = 68; |
||||
int32 field128 = 128; |
||||
string field129 = 129; |
||||
int32 field131 = 131; |
||||
} |
@ -0,0 +1,103 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"log" |
||||
"net" |
||||
"net/http" |
||||
_ "net/http/pprof" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||
"google.golang.org/grpc" |
||||
) |
||||
|
||||
const ( |
||||
iws = 65535 * 1000 |
||||
iwsc = 65535 * 10000 |
||||
readBuffer = 32 * 1024 |
||||
writeBuffer = 32 * 1024 |
||||
) |
||||
|
||||
var reqNum uint64 |
||||
|
||||
type Hello struct{} |
||||
|
||||
func (t *Hello) Say(ctx context.Context, args *proto.BenchmarkMessage) (reply *proto.BenchmarkMessage, err error) { |
||||
s := "OK" |
||||
var i int32 = 100 |
||||
args.Field1 = s |
||||
args.Field2 = i |
||||
atomic.AddUint64(&reqNum, 1) |
||||
return args, nil |
||||
} |
||||
|
||||
var host = flag.String("s", "0.0.0.0:8972", "listened ip and port") |
||||
var isWarden = flag.Bool("w", true, "is warden or grpc client") |
||||
|
||||
func main() { |
||||
go func() { |
||||
log.Println("run http at :6060") |
||||
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { |
||||
h := promhttp.Handler() |
||||
h.ServeHTTP(w, r) |
||||
}) |
||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) |
||||
}() |
||||
|
||||
flag.Parse() |
||||
|
||||
go stat() |
||||
if *isWarden { |
||||
runWarden() |
||||
} else { |
||||
runGrpc() |
||||
} |
||||
} |
||||
|
||||
func runGrpc() { |
||||
log.Println("run grpc") |
||||
lis, err := net.Listen("tcp", *host) |
||||
if err != nil { |
||||
log.Fatalf("failed to listen: %v", err) |
||||
} |
||||
s := grpc.NewServer(grpc.InitialWindowSize(iws), |
||||
grpc.InitialConnWindowSize(iwsc), |
||||
grpc.ReadBufferSize(readBuffer), |
||||
grpc.WriteBufferSize(writeBuffer)) |
||||
proto.RegisterHelloServer(s, &Hello{}) |
||||
s.Serve(lis) |
||||
} |
||||
|
||||
func runWarden() { |
||||
log.Println("run warden") |
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second * 3)}, |
||||
grpc.InitialWindowSize(iws), |
||||
grpc.InitialConnWindowSize(iwsc), |
||||
grpc.ReadBufferSize(readBuffer), |
||||
grpc.WriteBufferSize(writeBuffer)) |
||||
proto.RegisterHelloServer(s.Server(), &Hello{}) |
||||
s.Run(*host) |
||||
} |
||||
|
||||
func stat() { |
||||
ticker := time.NewTicker(time.Second * 5) |
||||
defer ticker.Stop() |
||||
var last uint64 |
||||
lastTs := uint64(time.Now().UnixNano()) |
||||
for { |
||||
<-ticker.C |
||||
now := atomic.LoadUint64(&reqNum) |
||||
nowTs := uint64(time.Now().UnixNano()) |
||||
qps := (now - last) * 1e6 / ((nowTs - lastTs) / 1e3) |
||||
last = now |
||||
lastTs = nowTs |
||||
log.Println("qps:", qps) |
||||
} |
||||
} |
@ -0,0 +1,15 @@ |
||||
#!/bin/bash |
||||
go build -o client greeter_client.go |
||||
echo size 100 concurrent 30 |
||||
./client -s 100 -c 30 |
||||
echo size 1000 concurrent 30 |
||||
./client -s 1000 -c 30 |
||||
echo size 10000 concurrent 30 |
||||
./client -s 10000 -c 30 |
||||
echo size 100 concurrent 300 |
||||
./client -s 100 -c 300 |
||||
echo size 1000 concurrent 300 |
||||
./client -s 1000 -c 300 |
||||
echo size 10000 concurrent 300 |
||||
./client -s 10000 -c 300 |
||||
rm client |
@ -0,0 +1,85 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
"math/rand" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
) |
||||
|
||||
var ( |
||||
ccf = &warden.ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
} |
||||
cli pb.GreeterClient |
||||
wg sync.WaitGroup |
||||
reqSize int |
||||
concurrency int |
||||
request int |
||||
all int64 |
||||
) |
||||
|
||||
func init() { |
||||
flag.IntVar(&reqSize, "s", 10, "request size") |
||||
flag.IntVar(&concurrency, "c", 10, "concurrency") |
||||
flag.IntVar(&request, "r", 1000, "request per routine") |
||||
} |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
name := randSeq(reqSize) |
||||
cli = newClient() |
||||
for i := 0; i < concurrency; i++ { |
||||
wg.Add(1) |
||||
go sayHello(&pb.HelloRequest{Name: name}) |
||||
} |
||||
wg.Wait() |
||||
fmt.Printf("per request cost %v\n", all/int64(request*concurrency)) |
||||
|
||||
} |
||||
|
||||
func sayHello(in *pb.HelloRequest) { |
||||
defer wg.Done() |
||||
now := time.Now() |
||||
for i := 0; i < request; i++ { |
||||
cli.SayHello(context.TODO(), in) |
||||
} |
||||
delta := time.Since(now) |
||||
atomic.AddInt64(&all, int64(delta/time.Millisecond)) |
||||
} |
||||
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") |
||||
|
||||
func randSeq(n int) string { |
||||
b := make([]rune, n) |
||||
for i := range b { |
||||
b[i] = letters[rand.Intn(len(letters))] |
||||
} |
||||
return string(b) |
||||
} |
||||
|
||||
func newClient() (cli pb.GreeterClient) { |
||||
client := warden.NewClient(ccf) |
||||
conn, err := client.Dial(context.TODO(), "127.0.0.1:9999") |
||||
if err != nil { |
||||
return |
||||
} |
||||
cli = pb.NewGreeterClient(conn) |
||||
return |
||||
} |
@ -0,0 +1,50 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"net/http" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||
) |
||||
|
||||
var ( |
||||
config = &warden.ServerConfig{Timeout: xtime.Duration(time.Second)} |
||||
) |
||||
|
||||
func main() { |
||||
newServer() |
||||
} |
||||
|
||||
type hello struct { |
||||
} |
||||
|
||||
func (s *hello) SayHello(c context.Context, in *pb.HelloRequest) (out *pb.HelloReply, err error) { |
||||
out = new(pb.HelloReply) |
||||
out.Message = in.Name |
||||
return |
||||
} |
||||
|
||||
func (s *hello) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
return nil |
||||
} |
||||
func newServer() { |
||||
server := warden.NewServer(config) |
||||
pb.RegisterGreeterServer(server.Server(), &hello{}) |
||||
go func() { |
||||
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { |
||||
h := promhttp.Handler() |
||||
h.ServeHTTP(w, r) |
||||
}) |
||||
http.ListenAndServe("0.0.0.0:9998", nil) |
||||
}() |
||||
err := server.Run(":9999") |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
} |
@ -0,0 +1,53 @@ |
||||
package codec |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/json" |
||||
|
||||
"github.com/gogo/protobuf/jsonpb" |
||||
"github.com/gogo/protobuf/proto" |
||||
"google.golang.org/grpc/encoding" |
||||
) |
||||
|
||||
//Reference https://jbrandhorst.com/post/grpc-json/
|
||||
func init() { |
||||
encoding.RegisterCodec(JSON{ |
||||
Marshaler: jsonpb.Marshaler{ |
||||
EmitDefaults: true, |
||||
OrigName: true, |
||||
}, |
||||
}) |
||||
} |
||||
|
||||
// JSON is impl of encoding.Codec
|
||||
type JSON struct { |
||||
jsonpb.Marshaler |
||||
jsonpb.Unmarshaler |
||||
} |
||||
|
||||
// Name is name of JSON
|
||||
func (j JSON) Name() string { |
||||
return "json" |
||||
} |
||||
|
||||
// Marshal is json marshal
|
||||
func (j JSON) Marshal(v interface{}) (out []byte, err error) { |
||||
if pm, ok := v.(proto.Message); ok { |
||||
b := new(bytes.Buffer) |
||||
err := j.Marshaler.Marshal(b, pm) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return b.Bytes(), nil |
||||
} |
||||
return json.Marshal(v) |
||||
} |
||||
|
||||
// Unmarshal is json unmarshal
|
||||
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) { |
||||
if pm, ok := v.(proto.Message); ok { |
||||
b := bytes.NewBuffer(data) |
||||
return j.Unmarshaler.Unmarshal(b, pm) |
||||
} |
||||
return json.Unmarshal(data, v) |
||||
} |
@ -0,0 +1,31 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
) |
||||
|
||||
// usage: ./client -grpc.target=test.service=127.0.0.1:8080
|
||||
func main() { |
||||
log.Init(&log.Config{Stdout: true}) |
||||
flag.Parse() |
||||
conn, err := warden.NewClient(nil).Dial(context.Background(), "127.0.0.1:8081") |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
cli := pb.NewGreeterClient(conn) |
||||
normalCall(cli) |
||||
} |
||||
|
||||
func normalCall(cli pb.GreeterClient) { |
||||
reply, err := cli.SayHello(context.Background(), &pb.HelloRequest{Name: "tom", Age: 23}) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
fmt.Println("get reply:", *reply) |
||||
} |
@ -0,0 +1,191 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"flag" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"math/rand" |
||||
"net/http" |
||||
"os" |
||||
"strings" |
||||
|
||||
"github.com/gogo/protobuf/jsonpb" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/credentials" |
||||
"google.golang.org/grpc/encoding" |
||||
) |
||||
|
||||
// Reply for test
|
||||
type Reply struct { |
||||
res []byte |
||||
} |
||||
|
||||
type Discovery struct { |
||||
HttpClient *http.Client |
||||
Nodes []string |
||||
} |
||||
|
||||
var ( |
||||
data string |
||||
file string |
||||
method string |
||||
addr string |
||||
tlsCert string |
||||
tlsServerName string |
||||
appID string |
||||
env string |
||||
) |
||||
|
||||
//Reference https://jbrandhorst.com/post/grpc-json/
|
||||
func init() { |
||||
encoding.RegisterCodec(JSON{ |
||||
Marshaler: jsonpb.Marshaler{ |
||||
EmitDefaults: true, |
||||
OrigName: true, |
||||
}, |
||||
}) |
||||
flag.StringVar(&data, "data", `{"name":"longxia","age":19}`, `{"name":"longxia","age":19}`) |
||||
flag.StringVar(&file, "file", ``, `./data.json`) |
||||
flag.StringVar(&method, "method", "/testproto.Greeter/SayHello", `/testproto.Greeter/SayHello`) |
||||
flag.StringVar(&addr, "addr", "127.0.0.1:8080", `127.0.0.1:8080`) |
||||
flag.StringVar(&tlsCert, "cert", "", `./cert.pem`) |
||||
flag.StringVar(&tlsServerName, "server_name", "", `hello_server`) |
||||
flag.StringVar(&appID, "appid", "", `appid`) |
||||
flag.StringVar(&env, "env", "", `env`) |
||||
} |
||||
|
||||
// 该example因为使用的是json传输格式所以只能用于调试或测试,用于线上会导致性能下降
|
||||
// 使用方法:
|
||||
// ./grpcDebug -data='{"name":"xia","age":19}' -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
|
||||
// ./grpcDebug -file=data.json -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
|
||||
// DEPLOY_ENV=uat ./grpcDebug -appid=main.community.reply-service -method=/reply.service.v1.Reply/ReplyInfoCache -data='{"rp_id"=1493769244}'
|
||||
func main() { |
||||
flag.Parse() |
||||
opts := []grpc.DialOption{ |
||||
grpc.WithInsecure(), |
||||
grpc.WithDefaultCallOptions(grpc.CallContentSubtype(JSON{}.Name())), |
||||
} |
||||
if tlsCert != "" { |
||||
creds, err := credentials.NewClientTLSFromFile(tlsCert, tlsServerName) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
opts = append(opts, grpc.WithTransportCredentials(creds)) |
||||
} |
||||
if file != "" { |
||||
content, err := ioutil.ReadFile(file) |
||||
if err != nil { |
||||
fmt.Println("ioutil.ReadFile %s failed!err:=%v", file, err) |
||||
os.Exit(1) |
||||
} |
||||
if len(content) > 0 { |
||||
data = string(content) |
||||
} |
||||
} |
||||
if appID != "" { |
||||
addr = ipFromDiscovery(appID, env) |
||||
} |
||||
conn, err := grpc.Dial(addr, opts...) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
var reply Reply |
||||
err = grpc.Invoke(context.Background(), method, []byte(data), &reply, conn) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
fmt.Println(string(reply.res)) |
||||
} |
||||
|
||||
func ipFromDiscovery(appID, env string) string { |
||||
d := &Discovery{ |
||||
Nodes: []string{"discovery.bilibili.co", "api.bilibili.co"}, |
||||
HttpClient: http.DefaultClient, |
||||
} |
||||
deployEnv := os.Getenv("DEPLOY_ENV") |
||||
if deployEnv != "" { |
||||
env = deployEnv |
||||
} |
||||
return d.addr(appID, env, d.nodes()) |
||||
} |
||||
|
||||
func (d *Discovery) nodes() (addrs []string) { |
||||
res := new(struct { |
||||
Code int `json:"code"` |
||||
Data []struct { |
||||
Addr string `json:"addr"` |
||||
} `json:"data"` |
||||
}) |
||||
resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/nodes", d.Nodes[rand.Intn(len(d.Nodes))])) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
defer resp.Body.Close() |
||||
if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { |
||||
panic(err) |
||||
} |
||||
for _, data := range res.Data { |
||||
addrs = append(addrs, data.Addr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (d *Discovery) addr(appID, env string, nodes []string) (ip string) { |
||||
res := new(struct { |
||||
Code int `json:"code"` |
||||
Message string `json:"message"` |
||||
Data map[string]*struct { |
||||
ZoneInstances map[string][]*struct { |
||||
AppID string `json:"appid"` |
||||
Addrs []string `json:"addrs"` |
||||
} `json:"zone_instances"` |
||||
} `json:"data"` |
||||
}) |
||||
host, _ := os.Hostname() |
||||
resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/polls?appid=%s&env=%s&hostname=%s", nodes[rand.Intn(len(nodes))], appID, env, host)) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
defer resp.Body.Close() |
||||
if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { |
||||
panic(err) |
||||
} |
||||
for _, data := range res.Data { |
||||
for _, zoneInstance := range data.ZoneInstances { |
||||
for _, instance := range zoneInstance { |
||||
if instance.AppID == appID { |
||||
for _, addr := range instance.Addrs { |
||||
if strings.Contains(addr, "grpc://") { |
||||
return strings.Replace(addr, "grpc://", "", -1) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// JSON is impl of encoding.Codec
|
||||
type JSON struct { |
||||
jsonpb.Marshaler |
||||
jsonpb.Unmarshaler |
||||
} |
||||
|
||||
// Name is name of JSON
|
||||
func (j JSON) Name() string { |
||||
return "json" |
||||
} |
||||
|
||||
// Marshal is json marshal
|
||||
func (j JSON) Marshal(v interface{}) (out []byte, err error) { |
||||
return v.([]byte), nil |
||||
} |
||||
|
||||
// Unmarshal is json unmarshal
|
||||
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) { |
||||
v.(*Reply).res = data |
||||
return nil |
||||
} |
@ -0,0 +1 @@ |
||||
{"name":"xia","age":19} |
@ -0,0 +1,108 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"os" |
||||
"os/signal" |
||||
"syscall" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
epb "github.com/bilibili/kratos/pkg/ecode/pb" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/golang/protobuf/ptypes" |
||||
"google.golang.org/grpc" |
||||
) |
||||
|
||||
type helloServer struct { |
||||
addr string |
||||
} |
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
if in.Name == "err_detail_test" { |
||||
any, _ := ptypes.MarshalAny(&pb.HelloReply{Success: true, Message: "this is test detail"}) |
||||
err := epb.From(ecode.AccessDenied) |
||||
err.ErrDetail = any |
||||
return nil, err |
||||
} |
||||
return &pb.HelloReply{Message: fmt.Sprintf("hello %s from %s", in.Name, s.addr)}, nil |
||||
} |
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
for i := 0; i < 3; i++ { |
||||
in, err := ss.Recv() |
||||
if err == io.EOF { |
||||
return nil |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true} |
||||
err = ss.Send(ret) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func runServer(addr string) *warden.Server { |
||||
server := warden.NewServer(&warden.ServerConfig{ |
||||
//服务端每个请求的默认超时时间
|
||||
Timeout: xtime.Duration(time.Second), |
||||
}) |
||||
server.Use(middleware()) |
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{addr: addr}) |
||||
go func() { |
||||
err := server.Run(addr) |
||||
if err != nil { |
||||
panic("run server failed!" + err.Error()) |
||||
} |
||||
}() |
||||
return server |
||||
} |
||||
|
||||
func main() { |
||||
log.Init(&log.Config{Stdout: true}) |
||||
server := runServer("0.0.0.0:8081") |
||||
signalHandler(server) |
||||
} |
||||
|
||||
//类似于中间件
|
||||
func middleware() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
||||
//记录调用方法
|
||||
log.Info("method:%s", info.FullMethod) |
||||
//call chain
|
||||
resp, err = handler(ctx, req) |
||||
return |
||||
} |
||||
} |
||||
|
||||
func signalHandler(s *warden.Server) { |
||||
var ( |
||||
ch = make(chan os.Signal, 1) |
||||
) |
||||
signal.Notify(ch, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) |
||||
for { |
||||
si := <-ch |
||||
switch si { |
||||
case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: |
||||
log.Info("get a signal %s, stop the consume process", si.String()) |
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) |
||||
defer cancel() |
||||
//gracefully shutdown with timeout
|
||||
s.Shutdown(ctx) |
||||
return |
||||
case syscall.SIGHUP: |
||||
default: |
||||
return |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,11 @@ |
||||
package metadata |
||||
|
||||
const ( |
||||
CPUUsage = "cpu_usage" |
||||
) |
||||
|
||||
// MD is context metadata for balancer and resolver
|
||||
type MD struct { |
||||
Weight uint64 |
||||
Color string |
||||
} |
@ -0,0 +1,642 @@ |
||||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: hello.proto
|
||||
|
||||
/* |
||||
Package testproto is a generated protocol buffer package. |
||||
|
||||
It is generated from these files: |
||||
hello.proto |
||||
|
||||
It has these top-level messages: |
||||
HelloRequest |
||||
HelloReply |
||||
*/ |
||||
package testproto |
||||
|
||||
import proto "github.com/golang/protobuf/proto" |
||||
import fmt "fmt" |
||||
import math "math" |
||||
import _ "github.com/gogo/protobuf/gogoproto" |
||||
|
||||
import context "golang.org/x/net/context" |
||||
import grpc "google.golang.org/grpc" |
||||
|
||||
import io "io" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = fmt.Errorf |
||||
var _ = math.Inf |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
// The request message containing the user's name.
|
||||
type HelloRequest struct { |
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name" validate:"required"` |
||||
Age int32 `protobuf:"varint,2,opt,name=age,proto3" json:"age" validate:"min=0"` |
||||
} |
||||
|
||||
func (m *HelloRequest) Reset() { *m = HelloRequest{} } |
||||
func (m *HelloRequest) String() string { return proto.CompactTextString(m) } |
||||
func (*HelloRequest) ProtoMessage() {} |
||||
func (*HelloRequest) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{0} } |
||||
|
||||
// The response message containing the greetings
|
||||
type HelloReply struct { |
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` |
||||
Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` |
||||
} |
||||
|
||||
func (m *HelloReply) Reset() { *m = HelloReply{} } |
||||
func (m *HelloReply) String() string { return proto.CompactTextString(m) } |
||||
func (*HelloReply) ProtoMessage() {} |
||||
func (*HelloReply) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{1} } |
||||
|
||||
func init() { |
||||
proto.RegisterType((*HelloRequest)(nil), "testproto.HelloRequest") |
||||
proto.RegisterType((*HelloReply)(nil), "testproto.HelloReply") |
||||
} |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context |
||||
var _ grpc.ClientConn |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4 |
||||
|
||||
// Client API for Greeter service
|
||||
|
||||
type GreeterClient interface { |
||||
// Sends a greeting
|
||||
SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) |
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply
|
||||
StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error) |
||||
} |
||||
|
||||
type greeterClient struct { |
||||
cc *grpc.ClientConn |
||||
} |
||||
|
||||
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { |
||||
return &greeterClient{cc} |
||||
} |
||||
|
||||
func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) { |
||||
out := new(HelloReply) |
||||
err := grpc.Invoke(ctx, "/testproto.Greeter/SayHello", in, out, c.cc, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *greeterClient) StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error) { |
||||
stream, err := grpc.NewClientStream(ctx, &_Greeter_serviceDesc.Streams[0], c.cc, "/testproto.Greeter/StreamHello", opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
x := &greeterStreamHelloClient{stream} |
||||
return x, nil |
||||
} |
||||
|
||||
type Greeter_StreamHelloClient interface { |
||||
Send(*HelloRequest) error |
||||
Recv() (*HelloReply, error) |
||||
grpc.ClientStream |
||||
} |
||||
|
||||
type greeterStreamHelloClient struct { |
||||
grpc.ClientStream |
||||
} |
||||
|
||||
func (x *greeterStreamHelloClient) Send(m *HelloRequest) error { |
||||
return x.ClientStream.SendMsg(m) |
||||
} |
||||
|
||||
func (x *greeterStreamHelloClient) Recv() (*HelloReply, error) { |
||||
m := new(HelloReply) |
||||
if err := x.ClientStream.RecvMsg(m); err != nil { |
||||
return nil, err |
||||
} |
||||
return m, nil |
||||
} |
||||
|
||||
// Server API for Greeter service
|
||||
|
||||
type GreeterServer interface { |
||||
// Sends a greeting
|
||||
SayHello(context.Context, *HelloRequest) (*HelloReply, error) |
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply
|
||||
StreamHello(Greeter_StreamHelloServer) error |
||||
} |
||||
|
||||
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { |
||||
s.RegisterService(&_Greeter_serviceDesc, srv) |
||||
} |
||||
|
||||
func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(HelloRequest) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(GreeterServer).SayHello(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/testproto.Greeter/SayHello", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Greeter_StreamHello_Handler(srv interface{}, stream grpc.ServerStream) error { |
||||
return srv.(GreeterServer).StreamHello(&greeterStreamHelloServer{stream}) |
||||
} |
||||
|
||||
type Greeter_StreamHelloServer interface { |
||||
Send(*HelloReply) error |
||||
Recv() (*HelloRequest, error) |
||||
grpc.ServerStream |
||||
} |
||||
|
||||
type greeterStreamHelloServer struct { |
||||
grpc.ServerStream |
||||
} |
||||
|
||||
func (x *greeterStreamHelloServer) Send(m *HelloReply) error { |
||||
return x.ServerStream.SendMsg(m) |
||||
} |
||||
|
||||
func (x *greeterStreamHelloServer) Recv() (*HelloRequest, error) { |
||||
m := new(HelloRequest) |
||||
if err := x.ServerStream.RecvMsg(m); err != nil { |
||||
return nil, err |
||||
} |
||||
return m, nil |
||||
} |
||||
|
||||
var _Greeter_serviceDesc = grpc.ServiceDesc{ |
||||
ServiceName: "testproto.Greeter", |
||||
HandlerType: (*GreeterServer)(nil), |
||||
Methods: []grpc.MethodDesc{ |
||||
{ |
||||
MethodName: "SayHello", |
||||
Handler: _Greeter_SayHello_Handler, |
||||
}, |
||||
}, |
||||
Streams: []grpc.StreamDesc{ |
||||
{ |
||||
StreamName: "StreamHello", |
||||
Handler: _Greeter_StreamHello_Handler, |
||||
ServerStreams: true, |
||||
ClientStreams: true, |
||||
}, |
||||
}, |
||||
Metadata: "hello.proto", |
||||
} |
||||
|
||||
func (m *HelloRequest) Marshal() (dAtA []byte, err error) { |
||||
size := m.Size() |
||||
dAtA = make([]byte, size) |
||||
n, err := m.MarshalTo(dAtA) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return dAtA[:n], nil |
||||
} |
||||
|
||||
func (m *HelloRequest) MarshalTo(dAtA []byte) (int, error) { |
||||
var i int |
||||
_ = i |
||||
var l int |
||||
_ = l |
||||
if len(m.Name) > 0 { |
||||
dAtA[i] = 0xa |
||||
i++ |
||||
i = encodeVarintHello(dAtA, i, uint64(len(m.Name))) |
||||
i += copy(dAtA[i:], m.Name) |
||||
} |
||||
if m.Age != 0 { |
||||
dAtA[i] = 0x10 |
||||
i++ |
||||
i = encodeVarintHello(dAtA, i, uint64(m.Age)) |
||||
} |
||||
return i, nil |
||||
} |
||||
|
||||
func (m *HelloReply) Marshal() (dAtA []byte, err error) { |
||||
size := m.Size() |
||||
dAtA = make([]byte, size) |
||||
n, err := m.MarshalTo(dAtA) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return dAtA[:n], nil |
||||
} |
||||
|
||||
func (m *HelloReply) MarshalTo(dAtA []byte) (int, error) { |
||||
var i int |
||||
_ = i |
||||
var l int |
||||
_ = l |
||||
if len(m.Message) > 0 { |
||||
dAtA[i] = 0xa |
||||
i++ |
||||
i = encodeVarintHello(dAtA, i, uint64(len(m.Message))) |
||||
i += copy(dAtA[i:], m.Message) |
||||
} |
||||
if m.Success { |
||||
dAtA[i] = 0x10 |
||||
i++ |
||||
if m.Success { |
||||
dAtA[i] = 1 |
||||
} else { |
||||
dAtA[i] = 0 |
||||
} |
||||
i++ |
||||
} |
||||
return i, nil |
||||
} |
||||
|
||||
func encodeVarintHello(dAtA []byte, offset int, v uint64) int { |
||||
for v >= 1<<7 { |
||||
dAtA[offset] = uint8(v&0x7f | 0x80) |
||||
v >>= 7 |
||||
offset++ |
||||
} |
||||
dAtA[offset] = uint8(v) |
||||
return offset + 1 |
||||
} |
||||
func (m *HelloRequest) Size() (n int) { |
||||
var l int |
||||
_ = l |
||||
l = len(m.Name) |
||||
if l > 0 { |
||||
n += 1 + l + sovHello(uint64(l)) |
||||
} |
||||
if m.Age != 0 { |
||||
n += 1 + sovHello(uint64(m.Age)) |
||||
} |
||||
return n |
||||
} |
||||
|
||||
func (m *HelloReply) Size() (n int) { |
||||
var l int |
||||
_ = l |
||||
l = len(m.Message) |
||||
if l > 0 { |
||||
n += 1 + l + sovHello(uint64(l)) |
||||
} |
||||
if m.Success { |
||||
n += 2 |
||||
} |
||||
return n |
||||
} |
||||
|
||||
func sovHello(x uint64) (n int) { |
||||
for { |
||||
n++ |
||||
x >>= 7 |
||||
if x == 0 { |
||||
break |
||||
} |
||||
} |
||||
return n |
||||
} |
||||
func sozHello(x uint64) (n int) { |
||||
return sovHello(uint64((x << 1) ^ uint64((int64(x) >> 63)))) |
||||
} |
||||
func (m *HelloRequest) Unmarshal(dAtA []byte) error { |
||||
l := len(dAtA) |
||||
iNdEx := 0 |
||||
for iNdEx < l { |
||||
preIndex := iNdEx |
||||
var wire uint64 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
wire |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
fieldNum := int32(wire >> 3) |
||||
wireType := int(wire & 0x7) |
||||
if wireType == 4 { |
||||
return fmt.Errorf("proto: HelloRequest: wiretype end group for non-group") |
||||
} |
||||
if fieldNum <= 0 { |
||||
return fmt.Errorf("proto: HelloRequest: illegal tag %d (wire type %d)", fieldNum, wire) |
||||
} |
||||
switch fieldNum { |
||||
case 1: |
||||
if wireType != 2 { |
||||
return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) |
||||
} |
||||
var stringLen uint64 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
stringLen |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
intStringLen := int(stringLen) |
||||
if intStringLen < 0 { |
||||
return ErrInvalidLengthHello |
||||
} |
||||
postIndex := iNdEx + intStringLen |
||||
if postIndex > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
m.Name = string(dAtA[iNdEx:postIndex]) |
||||
iNdEx = postIndex |
||||
case 2: |
||||
if wireType != 0 { |
||||
return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType) |
||||
} |
||||
m.Age = 0 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
m.Age |= (int32(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
default: |
||||
iNdEx = preIndex |
||||
skippy, err := skipHello(dAtA[iNdEx:]) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if skippy < 0 { |
||||
return ErrInvalidLengthHello |
||||
} |
||||
if (iNdEx + skippy) > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
iNdEx += skippy |
||||
} |
||||
} |
||||
|
||||
if iNdEx > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
return nil |
||||
} |
||||
func (m *HelloReply) Unmarshal(dAtA []byte) error { |
||||
l := len(dAtA) |
||||
iNdEx := 0 |
||||
for iNdEx < l { |
||||
preIndex := iNdEx |
||||
var wire uint64 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
wire |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
fieldNum := int32(wire >> 3) |
||||
wireType := int(wire & 0x7) |
||||
if wireType == 4 { |
||||
return fmt.Errorf("proto: HelloReply: wiretype end group for non-group") |
||||
} |
||||
if fieldNum <= 0 { |
||||
return fmt.Errorf("proto: HelloReply: illegal tag %d (wire type %d)", fieldNum, wire) |
||||
} |
||||
switch fieldNum { |
||||
case 1: |
||||
if wireType != 2 { |
||||
return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType) |
||||
} |
||||
var stringLen uint64 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
stringLen |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
intStringLen := int(stringLen) |
||||
if intStringLen < 0 { |
||||
return ErrInvalidLengthHello |
||||
} |
||||
postIndex := iNdEx + intStringLen |
||||
if postIndex > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
m.Message = string(dAtA[iNdEx:postIndex]) |
||||
iNdEx = postIndex |
||||
case 2: |
||||
if wireType != 0 { |
||||
return fmt.Errorf("proto: wrong wireType = %d for field Success", wireType) |
||||
} |
||||
var v int |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
v |= (int(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
m.Success = bool(v != 0) |
||||
default: |
||||
iNdEx = preIndex |
||||
skippy, err := skipHello(dAtA[iNdEx:]) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if skippy < 0 { |
||||
return ErrInvalidLengthHello |
||||
} |
||||
if (iNdEx + skippy) > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
iNdEx += skippy |
||||
} |
||||
} |
||||
|
||||
if iNdEx > l { |
||||
return io.ErrUnexpectedEOF |
||||
} |
||||
return nil |
||||
} |
||||
func skipHello(dAtA []byte) (n int, err error) { |
||||
l := len(dAtA) |
||||
iNdEx := 0 |
||||
for iNdEx < l { |
||||
var wire uint64 |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return 0, ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return 0, io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
wire |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
wireType := int(wire & 0x7) |
||||
switch wireType { |
||||
case 0: |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return 0, ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return 0, io.ErrUnexpectedEOF |
||||
} |
||||
iNdEx++ |
||||
if dAtA[iNdEx-1] < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
return iNdEx, nil |
||||
case 1: |
||||
iNdEx += 8 |
||||
return iNdEx, nil |
||||
case 2: |
||||
var length int |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return 0, ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return 0, io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
length |= (int(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
iNdEx += length |
||||
if length < 0 { |
||||
return 0, ErrInvalidLengthHello |
||||
} |
||||
return iNdEx, nil |
||||
case 3: |
||||
for { |
||||
var innerWire uint64 |
||||
var start int = iNdEx |
||||
for shift := uint(0); ; shift += 7 { |
||||
if shift >= 64 { |
||||
return 0, ErrIntOverflowHello |
||||
} |
||||
if iNdEx >= l { |
||||
return 0, io.ErrUnexpectedEOF |
||||
} |
||||
b := dAtA[iNdEx] |
||||
iNdEx++ |
||||
innerWire |= (uint64(b) & 0x7F) << shift |
||||
if b < 0x80 { |
||||
break |
||||
} |
||||
} |
||||
innerWireType := int(innerWire & 0x7) |
||||
if innerWireType == 4 { |
||||
break |
||||
} |
||||
next, err := skipHello(dAtA[start:]) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
iNdEx = start + next |
||||
} |
||||
return iNdEx, nil |
||||
case 4: |
||||
return iNdEx, nil |
||||
case 5: |
||||
iNdEx += 4 |
||||
return iNdEx, nil |
||||
default: |
||||
return 0, fmt.Errorf("proto: illegal wireType %d", wireType) |
||||
} |
||||
} |
||||
panic("unreachable") |
||||
} |
||||
|
||||
var ( |
||||
ErrInvalidLengthHello = fmt.Errorf("proto: negative length found during unmarshaling") |
||||
ErrIntOverflowHello = fmt.Errorf("proto: integer overflow") |
||||
) |
||||
|
||||
func init() { proto.RegisterFile("hello.proto", fileDescriptorHello) } |
||||
|
||||
var fileDescriptorHello = []byte{ |
||||
// 296 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x90, 0x3f, 0x4e, 0xc3, 0x30, |
||||
0x14, 0xc6, 0x63, 0xfe, 0xb5, 0x75, 0x19, 0x90, 0x11, 0x22, 0x2a, 0x92, 0x53, 0x79, 0xca, 0xd2, |
||||
0xb4, 0xa2, 0x1b, 0x02, 0x09, 0x85, 0x01, 0xe6, 0xf4, 0x04, 0x4e, 0xfa, 0x48, 0x23, 0x25, 0x75, |
||||
0x6a, 0x3b, 0x48, 0xb9, 0x03, 0x07, 0xe0, 0x48, 0x1d, 0x7b, 0x82, 0x88, 0x86, 0xad, 0x63, 0x4f, |
||||
0x80, 0x62, 0x28, 0x20, 0xb1, 0x75, 0x7b, 0x3f, 0x7f, 0xfa, 0x7e, 0x4f, 0x7e, 0xb8, 0x3b, 0x83, |
||||
0x34, 0x15, 0x5e, 0x2e, 0x85, 0x16, 0xa4, 0xa3, 0x41, 0x69, 0x33, 0xf6, 0x06, 0x71, 0xa2, 0x67, |
||||
0x45, 0xe8, 0x45, 0x22, 0x1b, 0xc6, 0x22, 0x16, 0x43, 0xf3, 0x1c, 0x16, 0xcf, 0x86, 0x0c, 0x98, |
||||
0xe9, 0xab, 0xc9, 0x24, 0x3e, 0x7d, 0x6a, 0x44, 0x01, 0x2c, 0x0a, 0x50, 0x9a, 0x8c, 0xf1, 0xd1, |
||||
0x9c, 0x67, 0x60, 0xa3, 0x3e, 0x72, 0x3b, 0xbe, 0xb3, 0xa9, 0x1c, 0xc3, 0xdb, 0xca, 0x39, 0x7f, |
||||
0xe1, 0x69, 0x32, 0xe5, 0x1a, 0x6e, 0x98, 0x84, 0x45, 0x91, 0x48, 0x98, 0xb2, 0xc0, 0x84, 0x64, |
||||
0x80, 0x0f, 0x79, 0x0c, 0xf6, 0x41, 0x1f, 0xb9, 0xc7, 0xfe, 0xd5, 0xa6, 0x72, 0x1a, 0xdc, 0x56, |
||||
0xce, 0xd9, 0x6f, 0x25, 0x4b, 0xe6, 0x77, 0x23, 0x16, 0x34, 0x01, 0xbb, 0xc7, 0xf8, 0x7b, 0x67, |
||||
0x9e, 0x96, 0xc4, 0xc6, 0xad, 0x0c, 0x94, 0x6a, 0x04, 0x66, 0x69, 0xb0, 0xc3, 0x26, 0x51, 0x45, |
||||
0x14, 0x81, 0x52, 0x46, 0xdd, 0x0e, 0x76, 0x78, 0xfd, 0x8a, 0x70, 0xeb, 0x51, 0x02, 0x68, 0x90, |
||||
0xe4, 0x16, 0xb7, 0x27, 0xbc, 0x34, 0x42, 0x72, 0xe9, 0xfd, 0x1c, 0xc2, 0xfb, 0xfb, 0xad, 0xde, |
||||
0xc5, 0xff, 0x20, 0x4f, 0x4b, 0x66, 0x91, 0x07, 0xdc, 0x9d, 0x68, 0x09, 0x3c, 0xdb, 0x53, 0xe0, |
||||
0xa2, 0x11, 0xf2, 0xed, 0xe5, 0x9a, 0x5a, 0xab, 0x35, 0xb5, 0x96, 0x35, 0x45, 0xab, 0x9a, 0xa2, |
||||
0xf7, 0x9a, 0xa2, 0xb7, 0x0f, 0x6a, 0x85, 0x27, 0xa6, 0x31, 0xfe, 0x0c, 0x00, 0x00, 0xff, 0xff, |
||||
0x13, 0x57, 0x88, 0x03, 0xae, 0x01, 0x00, 0x00, |
||||
} |
@ -0,0 +1,33 @@ |
||||
syntax = "proto3"; |
||||
|
||||
package testproto; |
||||
|
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||
|
||||
option (gogoproto.goproto_enum_prefix_all) = false; |
||||
option (gogoproto.goproto_getters_all) = false; |
||||
option (gogoproto.unmarshaler_all) = true; |
||||
option (gogoproto.marshaler_all) = true; |
||||
option (gogoproto.sizer_all) = true; |
||||
option (gogoproto.goproto_registration) = true; |
||||
|
||||
// The greeting service definition. |
||||
service Greeter { |
||||
// Sends a greeting |
||||
rpc SayHello (HelloRequest) returns (HelloReply) {} |
||||
|
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply |
||||
rpc StreamHello(stream HelloRequest) returns (stream HelloReply) {} |
||||
} |
||||
|
||||
// The request message containing the user's name. |
||||
message HelloRequest { |
||||
string name = 1 [(gogoproto.jsontag) = "name", (gogoproto.moretags) = "validate:\"required\""]; |
||||
int32 age = 2 [(gogoproto.jsontag) = "age", (gogoproto.moretags) = "validate:\"min=0\""]; |
||||
} |
||||
|
||||
// The response message containing the greetings |
||||
message HelloReply { |
||||
string message = 1; |
||||
bool success = 2; |
||||
} |
@ -0,0 +1,151 @@ |
||||
package status |
||||
|
||||
import ( |
||||
"context" |
||||
"strconv" |
||||
|
||||
"github.com/golang/protobuf/proto" |
||||
"github.com/golang/protobuf/ptypes" |
||||
"github.com/pkg/errors" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/ecode/pb" |
||||
) |
||||
|
||||
// togRPCCode convert ecode.Codo to gRPC code
|
||||
func togRPCCode(code ecode.Codes) codes.Code { |
||||
switch code.Code() { |
||||
case ecode.OK.Code(): |
||||
return codes.OK |
||||
case ecode.RequestErr.Code(): |
||||
return codes.InvalidArgument |
||||
case ecode.NothingFound.Code(): |
||||
return codes.NotFound |
||||
case ecode.Unauthorized.Code(): |
||||
return codes.Unauthenticated |
||||
case ecode.AccessDenied.Code(): |
||||
return codes.PermissionDenied |
||||
case ecode.LimitExceed.Code(): |
||||
return codes.ResourceExhausted |
||||
case ecode.MethodNotAllowed.Code(): |
||||
return codes.Unimplemented |
||||
case ecode.Deadline.Code(): |
||||
return codes.DeadlineExceeded |
||||
case ecode.ServiceUnavailable.Code(): |
||||
return codes.Unavailable |
||||
} |
||||
return codes.Unknown |
||||
} |
||||
|
||||
func toECode(gst *status.Status) ecode.Code { |
||||
gcode := gst.Code() |
||||
switch gcode { |
||||
case codes.OK: |
||||
return ecode.OK |
||||
case codes.InvalidArgument: |
||||
return ecode.RequestErr |
||||
case codes.NotFound: |
||||
return ecode.NothingFound |
||||
case codes.PermissionDenied: |
||||
return ecode.AccessDenied |
||||
case codes.Unauthenticated: |
||||
return ecode.Unauthorized |
||||
case codes.ResourceExhausted: |
||||
return ecode.LimitExceed |
||||
case codes.Unimplemented: |
||||
return ecode.MethodNotAllowed |
||||
case codes.DeadlineExceeded: |
||||
return ecode.Deadline |
||||
case codes.Unavailable: |
||||
return ecode.ServiceUnavailable |
||||
case codes.Unknown: |
||||
return ecode.String(gst.Message()) |
||||
} |
||||
return ecode.ServerErr |
||||
} |
||||
|
||||
// FromError convert error for service reply and try to convert it to grpc.Status.
|
||||
func FromError(svrErr error) (gst *status.Status) { |
||||
var err error |
||||
svrErr = errors.Cause(svrErr) |
||||
if code, ok := svrErr.(ecode.Codes); ok { |
||||
// TODO: deal with err
|
||||
if gst, err = gRPCStatusFromEcode(code); err == nil { |
||||
return |
||||
} |
||||
} |
||||
// for some special error convert context.Canceled to ecode.Canceled,
|
||||
// context.DeadlineExceeded to ecode.DeadlineExceeded only for raw error
|
||||
// if err be wrapped will not effect.
|
||||
switch svrErr { |
||||
case context.Canceled: |
||||
gst, _ = gRPCStatusFromEcode(ecode.Canceled) |
||||
case context.DeadlineExceeded: |
||||
gst, _ = gRPCStatusFromEcode(ecode.Deadline) |
||||
default: |
||||
gst, _ = status.FromError(svrErr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
func gRPCStatusFromEcode(code ecode.Codes) (*status.Status, error) { |
||||
var st *ecode.Status |
||||
switch v := code.(type) { |
||||
// compatible old pb.Error remove it after nobody use pb.Error.
|
||||
case *pb.Error: |
||||
return status.New(codes.Unknown, v.Error()).WithDetails(v) |
||||
case *ecode.Status: |
||||
st = v |
||||
case ecode.Code: |
||||
st = ecode.FromCode(v) |
||||
default: |
||||
st = ecode.Error(ecode.Code(code.Code()), code.Message()) |
||||
for _, detail := range code.Details() { |
||||
if msg, ok := detail.(proto.Message); ok { |
||||
st.WithDetails(msg) |
||||
} |
||||
} |
||||
} |
||||
// gst := status.New(togRPCCode(st), st.Message())
|
||||
// NOTE: compatible with PHP swoole gRPC put code in status message as string.
|
||||
// gst := status.New(togRPCCode(st), strconv.Itoa(st.Code()))
|
||||
gst := status.New(codes.Unknown, strconv.Itoa(st.Code())) |
||||
pbe := &pb.Error{ErrCode: int32(st.Code()), ErrMessage: gst.Message()} |
||||
// NOTE: server return ecode.Status will be covert to pb.Error details will be ignored
|
||||
// and put it at details[0] for compatible old client
|
||||
return gst.WithDetails(pbe, st.Proto()) |
||||
} |
||||
|
||||
// ToEcode convert grpc.status to ecode.Codes
|
||||
func ToEcode(gst *status.Status) ecode.Codes { |
||||
details := gst.Details() |
||||
// reverse range details, details may contain three case,
|
||||
// if details contain pb.Error and ecode.Status use eocde.Status first.
|
||||
//
|
||||
// Details layout:
|
||||
// pb.Error [0: pb.Error]
|
||||
// both pb.Error and ecode.Status [0: pb.Error, 1: ecode.Status]
|
||||
// ecode.Status [0: ecode.Status]
|
||||
for i := len(details) - 1; i >= 0; i-- { |
||||
detail := details[i] |
||||
// compatible with old pb.Error.
|
||||
if pe, ok := detail.(*pb.Error); ok { |
||||
st := ecode.Error(ecode.Code(pe.ErrCode), pe.ErrMessage) |
||||
if pe.ErrDetail != nil { |
||||
dynMsg := new(ptypes.DynamicAny) |
||||
// TODO deal with unmarshalAny error.
|
||||
if err := ptypes.UnmarshalAny(pe.ErrDetail, dynMsg); err == nil { |
||||
st, _ = st.WithDetails(dynMsg.Message) |
||||
} |
||||
} |
||||
return st |
||||
} |
||||
// convert detail to status only use first detail
|
||||
if pb, ok := detail.(proto.Message); ok { |
||||
return ecode.FromProto(pb) |
||||
} |
||||
} |
||||
return toECode(gst) |
||||
} |
@ -0,0 +1,164 @@ |
||||
package status |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/golang/protobuf/ptypes" |
||||
"github.com/golang/protobuf/ptypes/timestamp" |
||||
pkgerr "github.com/pkg/errors" |
||||
"github.com/stretchr/testify/assert" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/ecode/pb" |
||||
) |
||||
|
||||
func TestCodeConvert(t *testing.T) { |
||||
var table = map[codes.Code]ecode.Code{ |
||||
codes.OK: ecode.OK, |
||||
// codes.Canceled
|
||||
codes.Unknown: ecode.ServerErr, |
||||
codes.InvalidArgument: ecode.RequestErr, |
||||
codes.DeadlineExceeded: ecode.Deadline, |
||||
codes.NotFound: ecode.NothingFound, |
||||
// codes.AlreadyExists
|
||||
codes.PermissionDenied: ecode.AccessDenied, |
||||
codes.ResourceExhausted: ecode.LimitExceed, |
||||
// codes.FailedPrecondition
|
||||
// codes.Aborted
|
||||
// codes.OutOfRange
|
||||
codes.Unimplemented: ecode.MethodNotAllowed, |
||||
codes.Unavailable: ecode.ServiceUnavailable, |
||||
// codes.DataLoss
|
||||
codes.Unauthenticated: ecode.Unauthorized, |
||||
} |
||||
for k, v := range table { |
||||
assert.Equal(t, toECode(status.New(k, "-500")), v) |
||||
} |
||||
for k, v := range table { |
||||
assert.Equal(t, togRPCCode(v), k, fmt.Sprintf("togRPC code error: %d -> %d", v, k)) |
||||
} |
||||
} |
||||
|
||||
func TestNoDetailsConvert(t *testing.T) { |
||||
gst := status.New(codes.Unknown, "-2233") |
||||
assert.Equal(t, toECode(gst).Code(), -2233) |
||||
|
||||
gst = status.New(codes.Internal, "") |
||||
assert.Equal(t, toECode(gst).Code(), -500) |
||||
} |
||||
|
||||
func TestFromError(t *testing.T) { |
||||
t.Run("input general error", func(t *testing.T) { |
||||
err := errors.New("general error") |
||||
gst := FromError(err) |
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
assert.Contains(t, gst.Message(), "general") |
||||
}) |
||||
t.Run("input wrap error", func(t *testing.T) { |
||||
err := pkgerr.Wrap(ecode.RequestErr, "hh") |
||||
gst := FromError(err) |
||||
|
||||
assert.Equal(t, "-400", gst.Message()) |
||||
}) |
||||
t.Run("input ecode.Code", func(t *testing.T) { |
||||
err := ecode.RequestErr |
||||
gst := FromError(err) |
||||
|
||||
//assert.Equal(t, codes.InvalidArgument, gst.Code())
|
||||
// NOTE: set all grpc.status as Unkown when error is ecode.Codes for compatible
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
// NOTE: gst.Message == str(ecode.Code) for compatible php leagcy code
|
||||
assert.Equal(t, err.Message(), gst.Message()) |
||||
}) |
||||
t.Run("input raw Canceled", func(t *testing.T) { |
||||
gst := FromError(context.Canceled) |
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
assert.Equal(t, "-498", gst.Message()) |
||||
}) |
||||
t.Run("input raw DeadlineExceeded", func(t *testing.T) { |
||||
gst := FromError(context.DeadlineExceeded) |
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
assert.Equal(t, "-504", gst.Message()) |
||||
}) |
||||
t.Run("input pb.Error", func(t *testing.T) { |
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
detail, _ := ptypes.MarshalAny(m) |
||||
err := &pb.Error{ErrCode: 2233, ErrMessage: "message", ErrDetail: detail} |
||||
gst := FromError(err) |
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
assert.Len(t, gst.Details(), 1) |
||||
assert.Equal(t, "2233", gst.Message()) |
||||
}) |
||||
t.Run("input ecode.Status", func(t *testing.T) { |
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
err, _ := ecode.Error(ecode.Unauthorized, "unauthorized").WithDetails(m) |
||||
gst := FromError(err) |
||||
|
||||
//assert.Equal(t, codes.Unauthenticated, gst.Code())
|
||||
// NOTE: set all grpc.status as Unkown when error is ecode.Codes for compatible
|
||||
assert.Equal(t, codes.Unknown, gst.Code()) |
||||
assert.Len(t, gst.Details(), 2) |
||||
details := gst.Details() |
||||
assert.IsType(t, &pb.Error{}, details[0]) |
||||
assert.IsType(t, err.Proto(), details[1]) |
||||
}) |
||||
} |
||||
|
||||
func TestToEcode(t *testing.T) { |
||||
t.Run("input general grpc.Status", func(t *testing.T) { |
||||
gst := status.New(codes.Unknown, "unknown") |
||||
ec := ToEcode(gst) |
||||
|
||||
assert.Equal(t, int(ecode.ServerErr), ec.Code()) |
||||
assert.Equal(t, "-500", ec.Message()) |
||||
assert.Len(t, ec.Details(), 0) |
||||
}) |
||||
|
||||
t.Run("input pb.Error", func(t *testing.T) { |
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
detail, _ := ptypes.MarshalAny(m) |
||||
gst := status.New(codes.InvalidArgument, "requesterr") |
||||
gst, _ = gst.WithDetails(&pb.Error{ErrCode: 1122, ErrMessage: "message", ErrDetail: detail}) |
||||
ec := ToEcode(gst) |
||||
|
||||
assert.Equal(t, 1122, ec.Code()) |
||||
assert.Equal(t, "message", ec.Message()) |
||||
assert.Len(t, ec.Details(), 1) |
||||
assert.IsType(t, m, ec.Details()[0]) |
||||
}) |
||||
|
||||
t.Run("input pb.Error and ecode.Status", func(t *testing.T) { |
||||
gst := status.New(codes.InvalidArgument, "requesterr") |
||||
gst, _ = gst.WithDetails( |
||||
&pb.Error{ErrCode: 1122, ErrMessage: "message"}, |
||||
ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").Proto(), |
||||
) |
||||
ec := ToEcode(gst) |
||||
|
||||
assert.Equal(t, int(ecode.AccessKeyErr), ec.Code()) |
||||
assert.Equal(t, "AccessKeyErr", ec.Message()) |
||||
}) |
||||
|
||||
t.Run("input encode.Status", func(t *testing.T) { |
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()} |
||||
st, _ := ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").WithDetails(m) |
||||
gst := status.New(codes.InvalidArgument, "requesterr") |
||||
gst, _ = gst.WithDetails(st.Proto()) |
||||
ec := ToEcode(gst) |
||||
|
||||
assert.Equal(t, int(ecode.AccessKeyErr), ec.Code()) |
||||
assert.Equal(t, "AccessKeyErr", ec.Message()) |
||||
assert.Len(t, ec.Details(), 1) |
||||
assert.IsType(t, m, ec.Details()[0]) |
||||
}) |
||||
} |
@ -0,0 +1,118 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strconv" |
||||
"time" |
||||
|
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/peer" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/net/metadata" |
||||
"github.com/bilibili/kratos/pkg/stat" |
||||
) |
||||
|
||||
var ( |
||||
statsClient = stat.RPCClient |
||||
statsServer = stat.RPCServer |
||||
) |
||||
|
||||
func logFn(code int, dt time.Duration) func(context.Context, ...log.D) { |
||||
switch { |
||||
case code < 0: |
||||
return log.Errorv |
||||
case dt >= time.Millisecond*500: |
||||
// TODO: slowlog make it configurable.
|
||||
return log.Warnv |
||||
case code > 0: |
||||
return log.Warnv |
||||
} |
||||
return log.Infov |
||||
} |
||||
|
||||
// clientLogging warden grpc logging
|
||||
func clientLogging() grpc.UnaryClientInterceptor { |
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { |
||||
startTime := time.Now() |
||||
var peerInfo peer.Peer |
||||
opts = append(opts, grpc.Peer(&peerInfo)) |
||||
|
||||
// invoker requests
|
||||
err := invoker(ctx, method, req, reply, cc, opts...) |
||||
|
||||
// after request
|
||||
code := ecode.Cause(err).Code() |
||||
duration := time.Since(startTime) |
||||
// monitor
|
||||
statsClient.Timing(method, int64(duration/time.Millisecond)) |
||||
statsClient.Incr(method, strconv.Itoa(code)) |
||||
|
||||
var ip string |
||||
if peerInfo.Addr != nil { |
||||
ip = peerInfo.Addr.String() |
||||
} |
||||
logFields := []log.D{ |
||||
log.KVString("ip", ip), |
||||
log.KVString("path", method), |
||||
log.KVInt("ret", code), |
||||
// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
|
||||
log.KVString("args", req.(fmt.Stringer).String()), |
||||
log.KVFloat64("ts", duration.Seconds()), |
||||
log.KVString("source", "grpc-access-log"), |
||||
} |
||||
if err != nil { |
||||
logFields = append(logFields, log.KV("error", err.Error()), log.KVString("stack", fmt.Sprintf("%+v", err))) |
||||
} |
||||
logFn(code, duration)(ctx, logFields...) |
||||
return err |
||||
} |
||||
} |
||||
|
||||
// serverLogging warden grpc logging
|
||||
func serverLogging() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
startTime := time.Now() |
||||
caller := metadata.String(ctx, metadata.Caller) |
||||
if caller == "" { |
||||
caller = "no_user" |
||||
} |
||||
var remoteIP string |
||||
if peerInfo, ok := peer.FromContext(ctx); ok { |
||||
remoteIP = peerInfo.Addr.String() |
||||
} |
||||
var quota float64 |
||||
if deadline, ok := ctx.Deadline(); ok { |
||||
quota = time.Until(deadline).Seconds() |
||||
} |
||||
|
||||
// call server handler
|
||||
resp, err := handler(ctx, req) |
||||
|
||||
// after server response
|
||||
code := ecode.Cause(err).Code() |
||||
duration := time.Since(startTime) |
||||
|
||||
// monitor
|
||||
statsServer.Timing(caller, int64(duration/time.Millisecond), info.FullMethod) |
||||
statsServer.Incr(caller, info.FullMethod, strconv.Itoa(code)) |
||||
logFields := []log.D{ |
||||
log.KVString("user", caller), |
||||
log.KVString("ip", remoteIP), |
||||
log.KVString("path", info.FullMethod), |
||||
log.KVInt("ret", code), |
||||
// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
|
||||
log.KVString("args", req.(fmt.Stringer).String()), |
||||
log.KVFloat64("ts", duration.Seconds()), |
||||
log.KVFloat64("timeout_quota", quota), |
||||
log.KVString("source", "grpc-access-log"), |
||||
} |
||||
if err != nil { |
||||
logFields = append(logFields, log.KV("error", err.Error()), log.KV("stack", fmt.Sprintf("%+v", err))) |
||||
} |
||||
logFn(code, duration)(ctx, logFields...) |
||||
return resp, err |
||||
} |
||||
} |
@ -0,0 +1,55 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/log" |
||||
) |
||||
|
||||
func Test_logFn(t *testing.T) { |
||||
type args struct { |
||||
code int |
||||
dt time.Duration |
||||
} |
||||
tests := []struct { |
||||
name string |
||||
args args |
||||
want func(context.Context, ...log.D) |
||||
}{ |
||||
{ |
||||
name: "ok", |
||||
args: args{code: 0, dt: time.Millisecond}, |
||||
want: log.Infov, |
||||
}, |
||||
{ |
||||
name: "slowlog", |
||||
args: args{code: 0, dt: time.Second}, |
||||
want: log.Warnv, |
||||
}, |
||||
{ |
||||
name: "business error", |
||||
args: args{code: 2233, dt: time.Millisecond}, |
||||
want: log.Warnv, |
||||
}, |
||||
{ |
||||
name: "system error", |
||||
args: args{code: -1, dt: 0}, |
||||
want: log.Errorv, |
||||
}, |
||||
{ |
||||
name: "system error and slowlog", |
||||
args: args{code: -1, dt: time.Second}, |
||||
want: log.Errorv, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
if got := logFn(tt.args.code, tt.args.dt); reflect.ValueOf(got).Pointer() != reflect.ValueOf(tt.want).Pointer() { |
||||
t.Errorf("unexpect log function!") |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,61 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"runtime" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
|
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
// recovery is a server interceptor that recovers from any panics.
|
||||
func (s *Server) recovery() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
||||
defer func() { |
||||
if rerr := recover(); rerr != nil { |
||||
const size = 64 << 10 |
||||
buf := make([]byte, size) |
||||
rs := runtime.Stack(buf, false) |
||||
if rs > size { |
||||
rs = size |
||||
} |
||||
buf = buf[:rs] |
||||
pl := fmt.Sprintf("grpc server panic: %v\n%v\n%s\n", req, rerr, buf) |
||||
fmt.Fprintf(os.Stderr, pl) |
||||
log.Error(pl) |
||||
err = status.Errorf(codes.Unknown, ecode.ServerErr.Error()) |
||||
} |
||||
}() |
||||
resp, err = handler(ctx, req) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// recovery return a client interceptor that recovers from any panics.
|
||||
func (c *Client) recovery() grpc.UnaryClientInterceptor { |
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { |
||||
defer func() { |
||||
if rerr := recover(); rerr != nil { |
||||
const size = 64 << 10 |
||||
buf := make([]byte, size) |
||||
rs := runtime.Stack(buf, false) |
||||
if rs > size { |
||||
rs = size |
||||
} |
||||
buf = buf[:rs] |
||||
pl := fmt.Sprintf("grpc client panic: %v\n%v\n%v\n%s\n", req, reply, rerr, buf) |
||||
fmt.Fprintf(os.Stderr, pl) |
||||
log.Error(pl) |
||||
err = ecode.ServerErr |
||||
} |
||||
}() |
||||
err = invoker(ctx, method, req, reply, cc, opts...) |
||||
return |
||||
} |
||||
} |
@ -0,0 +1,17 @@ |
||||
### business/warden/resolver |
||||
|
||||
##### Version 1.1.1 |
||||
1. add dial helper |
||||
|
||||
##### Version 1.1.0 |
||||
1. 增加了子集选择算法 |
||||
|
||||
##### Version 1.0.2 |
||||
1. 增加GET接口 |
||||
|
||||
##### Version 1.0.1 |
||||
1. 支持zone和clusters |
||||
|
||||
|
||||
##### Version 1.0.0 |
||||
1. 实现了基本的服务发现功能 |
@ -0,0 +1,9 @@ |
||||
# See the OWNERS docs at https://go.k8s.io/owners |
||||
|
||||
approvers: |
||||
- caoguoliang |
||||
labels: |
||||
- library |
||||
reviewers: |
||||
- caoguoliang |
||||
- maojian |
@ -0,0 +1,13 @@ |
||||
#### business/warden/resolver |
||||
|
||||
##### 项目简介 |
||||
|
||||
warden 的 服务发现模块,用于从底层的注册中心中获取Server节点列表并返回给GRPC |
||||
|
||||
##### 编译环境 |
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行** |
||||
|
||||
##### 依赖包 |
||||
|
||||
- [grpc](google.golang.org/grpc) |
@ -0,0 +1,6 @@ |
||||
### business/warden/resolver/direct |
||||
|
||||
|
||||
##### Version 1.0.0 |
||||
|
||||
1. 实现了基本的服务发现直连功能 |
@ -0,0 +1,77 @@ |
||||
package direct |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver" |
||||
) |
||||
|
||||
const ( |
||||
// Name is the name of direct resolver
|
||||
Name = "direct" |
||||
) |
||||
|
||||
var _ naming.Resolver = &Direct{} |
||||
|
||||
// New return Direct
|
||||
func New() *Direct { |
||||
return &Direct{} |
||||
} |
||||
|
||||
// Build build direct.
|
||||
func Build(id string) *Direct { |
||||
return &Direct{id: id} |
||||
} |
||||
|
||||
// Direct is a resolver for conneting endpoints directly.
|
||||
// example format: direct://default/192.168.1.1:8080,192.168.1.2:8081
|
||||
type Direct struct { |
||||
id string |
||||
} |
||||
|
||||
// Build direct build.
|
||||
func (d *Direct) Build(id string) naming.Resolver { |
||||
return &Direct{id: id} |
||||
} |
||||
|
||||
// Scheme return the Scheme of Direct
|
||||
func (d *Direct) Scheme() string { |
||||
return Name |
||||
} |
||||
|
||||
// Watch a tree
|
||||
func (d *Direct) Watch() <-chan struct{} { |
||||
ch := make(chan struct{}, 1) |
||||
ch <- struct{}{} |
||||
return ch |
||||
} |
||||
|
||||
//Unwatch a tree
|
||||
func (d *Direct) Unwatch(id string) { |
||||
} |
||||
|
||||
//Fetch fetch isntances
|
||||
func (d *Direct) Fetch(ctx context.Context) (insMap map[string][]*naming.Instance, found bool) { |
||||
var ins []*naming.Instance |
||||
|
||||
addrs := strings.Split(d.id, ",") |
||||
for _, addr := range addrs { |
||||
ins = append(ins, &naming.Instance{ |
||||
Addrs: []string{fmt.Sprintf("%s://%s", resolver.Scheme, addr)}, |
||||
}) |
||||
} |
||||
if len(ins) > 0 { |
||||
found = true |
||||
} |
||||
insMap = map[string][]*naming.Instance{env.Zone: ins} |
||||
return |
||||
} |
||||
|
||||
//Close close Direct
|
||||
func (d *Direct) Close() error { |
||||
return nil |
||||
} |
@ -0,0 +1,85 @@ |
||||
package direct |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
) |
||||
|
||||
type testServer struct { |
||||
name string |
||||
} |
||||
|
||||
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
return &pb.HelloReply{Message: ts.name, Success: true}, nil |
||||
} |
||||
|
||||
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
panic("not implement error") |
||||
} |
||||
|
||||
func createServer(name, listen string) *warden.Server { |
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)}) |
||||
ts := &testServer{name} |
||||
pb.RegisterGreeterServer(s.Server(), ts) |
||||
go func() { |
||||
if err := s.Run(listen); err != nil { |
||||
panic(fmt.Sprintf("run warden server fail! err: %s", err)) |
||||
} |
||||
}() |
||||
return s |
||||
} |
||||
|
||||
func TestMain(m *testing.M) { |
||||
resolver.Register(New()) |
||||
ctx := context.TODO() |
||||
s1 := createServer("server1", "127.0.0.1:18081") |
||||
s2 := createServer("server2", "127.0.0.1:18082") |
||||
defer s1.Shutdown(ctx) |
||||
defer s2.Shutdown(ctx) |
||||
os.Exit(m.Run()) |
||||
} |
||||
|
||||
func createTestClient(t *testing.T, connStr string) pb.GreeterClient { |
||||
client := warden.NewClient(&warden.ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
}) |
||||
conn, err := client.Dial(context.TODO(), connStr) |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
return pb.NewGreeterClient(conn) |
||||
} |
||||
|
||||
func TestDirect(t *testing.T) { |
||||
cli := createTestClient(t, "direct://default/127.0.0.1:18083,127.0.0.1:18082") |
||||
count := 0 |
||||
for i := 0; i < 10; i++ { |
||||
if resp, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("TestDirect: SayHello failed!err:=%v", err) |
||||
} else { |
||||
if resp.Message == "server2" { |
||||
count++ |
||||
} |
||||
} |
||||
} |
||||
if count != 10 { |
||||
t.Fatalf("TestDirect: get server2 times must be 10") |
||||
} |
||||
} |
@ -0,0 +1,204 @@ |
||||
package resolver |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"math/rand" |
||||
"net/url" |
||||
"os" |
||||
"sort" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
|
||||
"github.com/dgryski/go-farm" |
||||
"github.com/pkg/errors" |
||||
"google.golang.org/grpc/resolver" |
||||
) |
||||
|
||||
const ( |
||||
// Scheme is the scheme of discovery address
|
||||
Scheme = "grpc" |
||||
) |
||||
|
||||
var ( |
||||
_ resolver.Resolver = &Resolver{} |
||||
_ resolver.Builder = &Builder{} |
||||
mu sync.Mutex |
||||
) |
||||
|
||||
// Register register resolver builder if nil.
|
||||
func Register(b naming.Builder) { |
||||
mu.Lock() |
||||
defer mu.Unlock() |
||||
if resolver.Get(b.Scheme()) == nil { |
||||
resolver.Register(&Builder{b}) |
||||
} |
||||
} |
||||
|
||||
// Set override any registered builder
|
||||
func Set(b naming.Builder) { |
||||
mu.Lock() |
||||
defer mu.Unlock() |
||||
resolver.Register(&Builder{b}) |
||||
} |
||||
|
||||
// Builder is also a resolver builder.
|
||||
// It's build() function always returns itself.
|
||||
type Builder struct { |
||||
naming.Builder |
||||
} |
||||
|
||||
// Build returns itself for Resolver, because it's both a builder and a resolver.
|
||||
func (b *Builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) { |
||||
var zone = env.Zone |
||||
ss := int64(50) |
||||
clusters := map[string]struct{}{} |
||||
str := strings.SplitN(target.Endpoint, "?", 2) |
||||
if len(str) == 0 { |
||||
return nil, errors.Errorf("warden resolver: parse target.Endpoint(%s) failed!err:=endpoint is empty", target.Endpoint) |
||||
} else if len(str) == 2 { |
||||
m, err := url.ParseQuery(str[1]) |
||||
if err == nil { |
||||
for _, c := range m[naming.MetaCluster] { |
||||
clusters[c] = struct{}{} |
||||
} |
||||
zones := m[naming.MetaZone] |
||||
if len(zones) > 0 { |
||||
zone = zones[0] |
||||
} |
||||
if sub, ok := m["subset"]; ok { |
||||
if t, err := strconv.ParseInt(sub[0], 10, 64); err == nil { |
||||
ss = t |
||||
} |
||||
|
||||
} |
||||
} |
||||
} |
||||
r := &Resolver{ |
||||
nr: b.Builder.Build(str[0]), |
||||
cc: cc, |
||||
quit: make(chan struct{}, 1), |
||||
clusters: clusters, |
||||
zone: zone, |
||||
subsetSize: ss, |
||||
} |
||||
go r.updateproc() |
||||
return r, nil |
||||
} |
||||
|
||||
// Resolver watches for the updates on the specified target.
|
||||
// Updates include address updates and service config updates.
|
||||
type Resolver struct { |
||||
nr naming.Resolver |
||||
cc resolver.ClientConn |
||||
quit chan struct{} |
||||
|
||||
clusters map[string]struct{} |
||||
zone string |
||||
subsetSize int64 |
||||
} |
||||
|
||||
// Close is a noop for Resolver.
|
||||
func (r *Resolver) Close() { |
||||
select { |
||||
case r.quit <- struct{}{}: |
||||
r.nr.Close() |
||||
default: |
||||
} |
||||
} |
||||
|
||||
// ResolveNow is a noop for Resolver.
|
||||
func (r *Resolver) ResolveNow(o resolver.ResolveNowOption) { |
||||
} |
||||
|
||||
func (r *Resolver) updateproc() { |
||||
event := r.nr.Watch() |
||||
for { |
||||
select { |
||||
case <-r.quit: |
||||
return |
||||
case _, ok := <-event: |
||||
if !ok { |
||||
return |
||||
} |
||||
} |
||||
if insInfo, ok := r.nr.Fetch(context.Background()); ok { |
||||
instances, ok := insInfo.Instances[r.zone] |
||||
if !ok { |
||||
for _, value := range insInfo.Instances { |
||||
instances = append(instances, value...) |
||||
} |
||||
} |
||||
if r.subsetSize > 0 && len(instances) > 0 { |
||||
instances = r.subset(instances, env.Hostname, r.subsetSize) |
||||
} |
||||
r.newAddress(instances) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (r *Resolver) subset(backends []*naming.Instance, clientID string, size int64) []*naming.Instance { |
||||
if len(backends) <= int(size) { |
||||
return backends |
||||
} |
||||
sort.Slice(backends, func(i, j int) bool { |
||||
return backends[i].Hostname < backends[j].Hostname |
||||
}) |
||||
count := int64(len(backends)) / size |
||||
|
||||
id := farm.Fingerprint64([]byte(clientID)) |
||||
round := int64(id / uint64(count)) |
||||
|
||||
s := rand.NewSource(round) |
||||
ra := rand.New(s) |
||||
ra.Shuffle(len(backends), func(i, j int) { |
||||
backends[i], backends[j] = backends[j], backends[i] |
||||
}) |
||||
start := (id % uint64(count)) * uint64(size) |
||||
return backends[int(start) : int(start)+int(size)] |
||||
} |
||||
|
||||
func (r *Resolver) newAddress(instances []*naming.Instance) { |
||||
if len(instances) <= 0 { |
||||
return |
||||
} |
||||
addrs := make([]resolver.Address, 0, len(instances)) |
||||
for _, ins := range instances { |
||||
if len(r.clusters) > 0 { |
||||
if _, ok := r.clusters[ins.Metadata[naming.MetaCluster]]; !ok { |
||||
continue |
||||
} |
||||
} |
||||
|
||||
var weight int64 |
||||
if weight, _ = strconv.ParseInt(ins.Metadata[naming.MetaWeight], 10, 64); weight <= 0 { |
||||
weight = 10 |
||||
} |
||||
var rpc string |
||||
for _, a := range ins.Addrs { |
||||
u, err := url.Parse(a) |
||||
if err == nil && u.Scheme == Scheme { |
||||
rpc = u.Host |
||||
} |
||||
} |
||||
if rpc == "" { |
||||
fmt.Fprintf(os.Stderr, "warden/resolver: app(%s,%s) no valid grpc address(%v) found!", ins.AppID, ins.Hostname, ins.Addrs) |
||||
log.Warn("warden/resolver: invalid rpc address(%s,%s,%v) found!", ins.AppID, ins.Hostname, ins.Addrs) |
||||
continue |
||||
} |
||||
addr := resolver.Address{ |
||||
Addr: rpc, |
||||
Type: resolver.Backend, |
||||
ServerName: ins.AppID, |
||||
Metadata: wmeta.MD{Weight: uint64(weight), Color: ins.Metadata[naming.MetaColor]}, |
||||
} |
||||
addrs = append(addrs, addr) |
||||
} |
||||
r.cc.NewAddress(addrs) |
||||
} |
@ -0,0 +1,87 @@ |
||||
package resolver |
||||
|
||||
import ( |
||||
"context" |
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
) |
||||
|
||||
type mockDiscoveryBuilder struct { |
||||
instances map[string]*naming.Instance |
||||
watchch map[string][]*mockDiscoveryResolver |
||||
} |
||||
|
||||
func (mb *mockDiscoveryBuilder) Build(id string) naming.Resolver { |
||||
mr := &mockDiscoveryResolver{ |
||||
d: mb, |
||||
watchch: make(chan struct{}, 1), |
||||
} |
||||
mb.watchch[id] = append(mb.watchch[id], mr) |
||||
mr.watchch <- struct{}{} |
||||
return mr |
||||
} |
||||
func (mb *mockDiscoveryBuilder) Scheme() string { |
||||
return "mockdiscovery" |
||||
} |
||||
|
||||
type mockDiscoveryResolver struct { |
||||
//instances map[string]*naming.Instance
|
||||
d *mockDiscoveryBuilder |
||||
watchch chan struct{} |
||||
} |
||||
|
||||
var _ naming.Resolver = &mockDiscoveryResolver{} |
||||
|
||||
func (md *mockDiscoveryResolver) Fetch(ctx context.Context) (map[string][]*naming.Instance, bool) { |
||||
zones := make(map[string][]*naming.Instance) |
||||
for _, v := range md.d.instances { |
||||
zones[v.Zone] = append(zones[v.Zone], v) |
||||
} |
||||
return zones, len(zones) > 0 |
||||
} |
||||
|
||||
func (md *mockDiscoveryResolver) Watch() <-chan struct{} { |
||||
return md.watchch |
||||
} |
||||
|
||||
func (md *mockDiscoveryResolver) Close() error { |
||||
close(md.watchch) |
||||
return nil |
||||
} |
||||
|
||||
func (md *mockDiscoveryResolver) Scheme() string { |
||||
return "mockdiscovery" |
||||
} |
||||
|
||||
func (mb *mockDiscoveryBuilder) registry(appID string, hostname, rpc string, metadata map[string]string) { |
||||
ins := &naming.Instance{ |
||||
AppID: appID, |
||||
Env: "hello=world", |
||||
Hostname: hostname, |
||||
Addrs: []string{"grpc://" + rpc}, |
||||
Version: "1.1", |
||||
Zone: env.Zone, |
||||
Metadata: metadata, |
||||
} |
||||
mb.instances[hostname] = ins |
||||
if ch, ok := mb.watchch[appID]; ok { |
||||
var bullet struct{} |
||||
for _, c := range ch { |
||||
c.watchch <- bullet |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (mb *mockDiscoveryBuilder) cancel(hostname string) { |
||||
ins, ok := mb.instances[hostname] |
||||
if !ok { |
||||
return |
||||
} |
||||
delete(mb.instances, hostname) |
||||
if ch, ok := mb.watchch[ins.AppID]; ok { |
||||
var bullet struct{} |
||||
for _, c := range ch { |
||||
c.watchch <- bullet |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,312 @@ |
||||
package resolver |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env" |
||||
"github.com/bilibili/kratos/pkg/naming" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
var testServerMap map[string]*testServer |
||||
|
||||
func init() { |
||||
testServerMap = make(map[string]*testServer) |
||||
} |
||||
|
||||
const testAppID = "main.test" |
||||
|
||||
type testServer struct { |
||||
SayHelloCount int |
||||
} |
||||
|
||||
func resetCount() { |
||||
for _, s := range testServerMap { |
||||
s.SayHelloCount = 0 |
||||
} |
||||
} |
||||
|
||||
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
ts.SayHelloCount++ |
||||
return &pb.HelloReply{Message: "hello", Success: true}, nil |
||||
} |
||||
|
||||
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
panic("not implement error") |
||||
} |
||||
|
||||
func createServer(name, listen string) *warden.Server { |
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)}) |
||||
ts := &testServer{} |
||||
testServerMap[name] = ts |
||||
pb.RegisterGreeterServer(s.Server(), ts) |
||||
go func() { |
||||
if err := s.Run(listen); err != nil { |
||||
panic(fmt.Sprintf("run warden server fail! err: %s", err)) |
||||
} |
||||
}() |
||||
return s |
||||
} |
||||
|
||||
func NSayHello(c pb.GreeterClient, n int) func(*testing.T) { |
||||
return func(t *testing.T) { |
||||
for i := 0; i < n; i++ { |
||||
if _, err := c.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func createTestClient(t *testing.T) pb.GreeterClient { |
||||
client := warden.NewClient(&warden.ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
}) |
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test") |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
return pb.NewGreeterClient(conn) |
||||
} |
||||
|
||||
var mockResolver *mockDiscoveryBuilder |
||||
|
||||
func newMockDiscoveryBuilder() *mockDiscoveryBuilder { |
||||
return &mockDiscoveryBuilder{ |
||||
instances: make(map[string]*naming.Instance), |
||||
watchch: make(map[string][]*mockDiscoveryResolver), |
||||
} |
||||
} |
||||
func TestMain(m *testing.M) { |
||||
ctx := context.TODO() |
||||
mockResolver = newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
s1 := createServer("server1", "127.0.0.1:18081") |
||||
s2 := createServer("server2", "127.0.0.1:18082") |
||||
s3 := createServer("server3", "127.0.0.1:18083") |
||||
s4 := createServer("server4", "127.0.0.1:18084") |
||||
s5 := createServer("server5", "127.0.0.1:18085") |
||||
defer s1.Shutdown(ctx) |
||||
defer s2.Shutdown(ctx) |
||||
defer s3.Shutdown(ctx) |
||||
defer s4.Shutdown(ctx) |
||||
defer s5.Shutdown(ctx) |
||||
os.Exit(m.Run()) |
||||
} |
||||
|
||||
func TestAddResolver(t *testing.T) { |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
c := createTestClient(t) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount) |
||||
resetCount() |
||||
} |
||||
|
||||
func TestDeleteResolver(t *testing.T) { |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{}) |
||||
c := createTestClient(t) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount) |
||||
|
||||
mockResolver.cancel("server1") |
||||
resetCount() |
||||
time.Sleep(time.Millisecond * 10) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestUpdateResolver(t *testing.T) { |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{}) |
||||
|
||||
c := createTestClient(t) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount) |
||||
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18083", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18084", map[string]string{}) |
||||
resetCount() |
||||
time.Sleep(time.Millisecond * 10) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 10, testServerMap["server3"].SayHelloCount+testServerMap["server4"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestErrorResolver(t *testing.T) { |
||||
mockResolver := newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server6", "127.0.0.1:18086", map[string]string{}) |
||||
|
||||
c := createTestClient(t) |
||||
t.Run("test_say_hello", NSayHello(c, 10)) |
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestClusterResolver(t *testing.T) { |
||||
mockResolver := newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"}) |
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"}) |
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server5", "127.0.0.1:18084", map[string]string{}) |
||||
|
||||
client := warden.NewClient(&warden.ClientConfig{Clusters: []string{"c1"}}) |
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?cluster=c2") |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
time.Sleep(time.Millisecond * 10) |
||||
cli := pb.NewGreeterClient(conn) |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
assert.Equal(t, 1, testServerMap["server1"].SayHelloCount) |
||||
assert.Equal(t, 1, testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 1, testServerMap["server3"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestNoClusterResolver(t *testing.T) { |
||||
mockResolver := newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"}) |
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"}) |
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{}) |
||||
client := warden.NewClient(&warden.ClientConfig{}) |
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test") |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
time.Sleep(time.Millisecond * 20) |
||||
cli := pb.NewGreeterClient(conn) |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
assert.Equal(t, 1, testServerMap["server1"].SayHelloCount) |
||||
assert.Equal(t, 1, testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 1, testServerMap["server3"].SayHelloCount) |
||||
assert.Equal(t, 1, testServerMap["server4"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestZoneResolver(t *testing.T) { |
||||
mockResolver := newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
env.Zone = "testsh" |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{}) |
||||
env.Zone = "hhhh" |
||||
client := warden.NewClient(&warden.ClientConfig{Zone: "testsh"}) |
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test") |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
time.Sleep(time.Millisecond * 10) |
||||
cli := pb.NewGreeterClient(conn) |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount) |
||||
assert.Equal(t, 3, testServerMap["server2"].SayHelloCount) |
||||
|
||||
resetCount() |
||||
} |
||||
|
||||
func TestSubsetConn(t *testing.T) { |
||||
mockResolver := newMockDiscoveryBuilder() |
||||
resolver.Set(mockResolver) |
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{}) |
||||
mockResolver.registry(testAppID, "server5", "127.0.0.1:18085", map[string]string{}) |
||||
|
||||
client := warden.NewClient(nil) |
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?subset=3") |
||||
if err != nil { |
||||
t.Fatalf("create client fail!err%s", err) |
||||
} |
||||
time.Sleep(time.Millisecond * 20) |
||||
cli := pb.NewGreeterClient(conn) |
||||
for i := 0; i < 6; i++ { |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
} |
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server4"].SayHelloCount) |
||||
resetCount() |
||||
mockResolver.cancel("server4") |
||||
time.Sleep(time.Millisecond * 20) |
||||
for i := 0; i < 6; i++ { |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
} |
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server3"].SayHelloCount) |
||||
resetCount() |
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{}) |
||||
time.Sleep(time.Millisecond * 20) |
||||
for i := 0; i < 6; i++ { |
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil { |
||||
t.Fatalf("call sayhello fail! err: %s", err) |
||||
} |
||||
} |
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount) |
||||
assert.Equal(t, 2, testServerMap["server4"].SayHelloCount) |
||||
} |
@ -0,0 +1,16 @@ |
||||
package resolver |
||||
|
||||
import ( |
||||
"flag" |
||||
"fmt" |
||||
) |
||||
|
||||
// RegisterTarget will register grpc discovery mock address flag
|
||||
func RegisterTarget(target *string, discoveryID string) { |
||||
flag.CommandLine.StringVar( |
||||
target, |
||||
fmt.Sprintf("grpc.%s", discoveryID), |
||||
fmt.Sprintf("discovery://default/%s", discoveryID), |
||||
fmt.Sprintf("App's grpc target.\n example: -grpc.%s=\"127.0.0.1:9090\"", discoveryID), |
||||
) |
||||
} |
@ -0,0 +1,332 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
"math" |
||||
"net" |
||||
"os" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/dsn" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
"github.com/bilibili/kratos/pkg/net/trace" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
//this package is for json format response
|
||||
_ "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/encoding/json" |
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status" |
||||
|
||||
"github.com/pkg/errors" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/keepalive" |
||||
"google.golang.org/grpc/metadata" |
||||
"google.golang.org/grpc/peer" |
||||
"google.golang.org/grpc/reflection" |
||||
) |
||||
|
||||
var ( |
||||
_grpcDSN string |
||||
_defaultSerConf = &ServerConfig{ |
||||
Network: "tcp", |
||||
Addr: "0.0.0.0:9000", |
||||
Timeout: xtime.Duration(time.Second), |
||||
IdleTimeout: xtime.Duration(time.Second * 60), |
||||
MaxLifeTime: xtime.Duration(time.Hour * 2), |
||||
ForceCloseWait: xtime.Duration(time.Second * 20), |
||||
KeepAliveInterval: xtime.Duration(time.Second * 60), |
||||
KeepAliveTimeout: xtime.Duration(time.Second * 20), |
||||
} |
||||
_abortIndex int8 = math.MaxInt8 / 2 |
||||
) |
||||
|
||||
// ServerConfig is rpc server conf.
|
||||
type ServerConfig struct { |
||||
// Network is grpc listen network,default value is tcp
|
||||
Network string `dsn:"network"` |
||||
// Addr is grpc listen addr,default value is 0.0.0.0:9000
|
||||
Addr string `dsn:"address"` |
||||
// Timeout is context timeout for per rpc call.
|
||||
Timeout xtime.Duration `dsn:"query.timeout"` |
||||
// IdleTimeout is a duration for the amount of time after which an idle connection would be closed by sending a GoAway.
|
||||
// Idleness duration is defined since the most recent time the number of outstanding RPCs became zero or the connection establishment.
|
||||
IdleTimeout xtime.Duration `dsn:"query.idleTimeout"` |
||||
// MaxLifeTime is a duration for the maximum amount of time a connection may exist before it will be closed by sending a GoAway.
|
||||
// A random jitter of +/-10% will be added to MaxConnectionAge to spread out connection storms.
|
||||
MaxLifeTime xtime.Duration `dsn:"query.maxLife"` |
||||
// ForceCloseWait is an additive period after MaxLifeTime after which the connection will be forcibly closed.
|
||||
ForceCloseWait xtime.Duration `dsn:"query.closeWait"` |
||||
// KeepAliveInterval is after a duration of this time if the server doesn't see any activity it pings the client to see if the transport is still alive.
|
||||
KeepAliveInterval xtime.Duration `dsn:"query.keepaliveInterval"` |
||||
// KeepAliveTimeout is After having pinged for keepalive check, the server waits for a duration of Timeout and if no activity is seen even after that
|
||||
// the connection is closed.
|
||||
KeepAliveTimeout xtime.Duration `dsn:"query.keepaliveTimeout"` |
||||
} |
||||
|
||||
// Server is the framework's server side instance, it contains the GrpcServer, interceptor and interceptors.
|
||||
// Create an instance of Server, by using NewServer().
|
||||
type Server struct { |
||||
conf *ServerConfig |
||||
mutex sync.RWMutex |
||||
|
||||
server *grpc.Server |
||||
handlers []grpc.UnaryServerInterceptor |
||||
} |
||||
|
||||
// handle return a new unary server interceptor for OpenTracing\Logging\LinkTimeout.
|
||||
func (s *Server) handle() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
||||
var ( |
||||
cancel func() |
||||
addr string |
||||
) |
||||
s.mutex.RLock() |
||||
conf := s.conf |
||||
s.mutex.RUnlock() |
||||
// get derived timeout from grpc context,
|
||||
// compare with the warden configured,
|
||||
// and use the minimum one
|
||||
timeout := time.Duration(conf.Timeout) |
||||
if dl, ok := ctx.Deadline(); ok { |
||||
ctimeout := time.Until(dl) |
||||
if ctimeout-time.Millisecond*20 > 0 { |
||||
ctimeout = ctimeout - time.Millisecond*20 |
||||
} |
||||
if timeout > ctimeout { |
||||
timeout = ctimeout |
||||
} |
||||
} |
||||
ctx, cancel = context.WithTimeout(ctx, timeout) |
||||
defer cancel() |
||||
|
||||
// get grpc metadata(trace & remote_ip & color)
|
||||
var t trace.Trace |
||||
cmd := nmd.MD{} |
||||
if gmd, ok := metadata.FromIncomingContext(ctx); ok { |
||||
for key, vals := range gmd { |
||||
if nmd.IsIncomingKey(key) { |
||||
cmd[key] = vals[0] |
||||
} |
||||
} |
||||
} |
||||
if t == nil { |
||||
t = trace.New(args.FullMethod) |
||||
} else { |
||||
t.SetTitle(args.FullMethod) |
||||
} |
||||
|
||||
if pr, ok := peer.FromContext(ctx); ok { |
||||
addr = pr.Addr.String() |
||||
t.SetTag(trace.String(trace.TagAddress, addr)) |
||||
} |
||||
defer t.Finish(&err) |
||||
|
||||
// use common meta data context instead of grpc context
|
||||
ctx = nmd.NewContext(ctx, cmd) |
||||
ctx = trace.NewContext(ctx, t) |
||||
|
||||
resp, err = handler(ctx, req) |
||||
return resp, status.FromError(err).Err() |
||||
} |
||||
} |
||||
|
||||
func init() { |
||||
addFlag(flag.CommandLine) |
||||
} |
||||
|
||||
func addFlag(fs *flag.FlagSet) { |
||||
v := os.Getenv("GRPC") |
||||
if v == "" { |
||||
v = "tcp://0.0.0.0:9000/?timeout=1s&idle_timeout=60s" |
||||
} |
||||
fs.StringVar(&_grpcDSN, "grpc", v, "listen grpc dsn, or use GRPC env variable.") |
||||
fs.Var(&_grpcTarget, "grpc.target", "usage: -grpc.target=seq.service=127.0.0.1:9000 -grpc.target=fav.service=192.168.10.1:9000") |
||||
} |
||||
|
||||
func parseDSN(rawdsn string) *ServerConfig { |
||||
conf := new(ServerConfig) |
||||
d, err := dsn.Parse(rawdsn) |
||||
if err != nil { |
||||
panic(errors.WithMessage(err, fmt.Sprintf("warden: invalid dsn: %s", rawdsn))) |
||||
} |
||||
if _, err = d.Bind(conf); err != nil { |
||||
panic(errors.WithMessage(err, fmt.Sprintf("warden: invalid dsn: %s", rawdsn))) |
||||
} |
||||
return conf |
||||
} |
||||
|
||||
// NewServer returns a new blank Server instance with a default server interceptor.
|
||||
func NewServer(conf *ServerConfig, opt ...grpc.ServerOption) (s *Server) { |
||||
if conf == nil { |
||||
if !flag.Parsed() { |
||||
fmt.Fprint(os.Stderr, "[warden] please call flag.Parse() before Init warden server, some configure may not effect\n") |
||||
} |
||||
conf = parseDSN(_grpcDSN) |
||||
} |
||||
s = new(Server) |
||||
if err := s.SetConfig(conf); err != nil { |
||||
panic(errors.Errorf("warden: set config failed!err: %s", err.Error())) |
||||
} |
||||
keepParam := grpc.KeepaliveParams(keepalive.ServerParameters{ |
||||
MaxConnectionIdle: time.Duration(s.conf.IdleTimeout), |
||||
MaxConnectionAgeGrace: time.Duration(s.conf.ForceCloseWait), |
||||
Time: time.Duration(s.conf.KeepAliveInterval), |
||||
Timeout: time.Duration(s.conf.KeepAliveTimeout), |
||||
MaxConnectionAge: time.Duration(s.conf.MaxLifeTime), |
||||
}) |
||||
opt = append(opt, keepParam, grpc.UnaryInterceptor(s.interceptor)) |
||||
s.server = grpc.NewServer(opt...) |
||||
s.Use(s.recovery(), s.handle(), serverLogging(), s.stats(), s.validate()) |
||||
return |
||||
} |
||||
|
||||
// SetConfig hot reloads server config
|
||||
func (s *Server) SetConfig(conf *ServerConfig) (err error) { |
||||
if conf == nil { |
||||
conf = _defaultSerConf |
||||
} |
||||
if conf.Timeout <= 0 { |
||||
conf.Timeout = xtime.Duration(time.Second) |
||||
} |
||||
if conf.IdleTimeout <= 0 { |
||||
conf.IdleTimeout = xtime.Duration(time.Second * 60) |
||||
} |
||||
if conf.MaxLifeTime <= 0 { |
||||
conf.MaxLifeTime = xtime.Duration(time.Hour * 2) |
||||
} |
||||
if conf.ForceCloseWait <= 0 { |
||||
conf.ForceCloseWait = xtime.Duration(time.Second * 20) |
||||
} |
||||
if conf.KeepAliveInterval <= 0 { |
||||
conf.KeepAliveInterval = xtime.Duration(time.Second * 60) |
||||
} |
||||
if conf.KeepAliveTimeout <= 0 { |
||||
conf.KeepAliveTimeout = xtime.Duration(time.Second * 20) |
||||
} |
||||
if conf.Addr == "" { |
||||
conf.Addr = "0.0.0.0:9000" |
||||
} |
||||
if conf.Network == "" { |
||||
conf.Network = "tcp" |
||||
} |
||||
s.mutex.Lock() |
||||
s.conf = conf |
||||
s.mutex.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
// interceptor is a single interceptor out of a chain of many interceptors.
|
||||
// Execution is done in left-to-right order, including passing of context.
|
||||
// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
|
||||
// will see context changes of one and two.
|
||||
func (s *Server) interceptor(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
var ( |
||||
i int |
||||
chain grpc.UnaryHandler |
||||
) |
||||
|
||||
n := len(s.handlers) |
||||
if n == 0 { |
||||
return handler(ctx, req) |
||||
} |
||||
|
||||
chain = func(ic context.Context, ir interface{}) (interface{}, error) { |
||||
if i == n-1 { |
||||
return handler(ic, ir) |
||||
} |
||||
i++ |
||||
return s.handlers[i](ic, ir, args, chain) |
||||
} |
||||
|
||||
return s.handlers[0](ctx, req, args, chain) |
||||
} |
||||
|
||||
// Server return the grpc server for registering service.
|
||||
func (s *Server) Server() *grpc.Server { |
||||
return s.server |
||||
} |
||||
|
||||
// Use attachs a global inteceptor to the server.
|
||||
// For example, this is the right place for a rate limiter or error management inteceptor.
|
||||
func (s *Server) Use(handlers ...grpc.UnaryServerInterceptor) *Server { |
||||
finalSize := len(s.handlers) + len(handlers) |
||||
if finalSize >= int(_abortIndex) { |
||||
panic("warden: server use too many handlers") |
||||
} |
||||
mergedHandlers := make([]grpc.UnaryServerInterceptor, finalSize) |
||||
copy(mergedHandlers, s.handlers) |
||||
copy(mergedHandlers[len(s.handlers):], handlers) |
||||
s.handlers = mergedHandlers |
||||
return s |
||||
} |
||||
|
||||
// Run create a tcp listener and start goroutine for serving each incoming request.
|
||||
// Run will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) Run(addr string) error { |
||||
lis, err := net.Listen("tcp", addr) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
log.Error("failed to listen: %v", err) |
||||
return err |
||||
} |
||||
reflection.Register(s.server) |
||||
return s.Serve(lis) |
||||
} |
||||
|
||||
// RunUnix create a unix listener and start goroutine for serving each incoming request.
|
||||
// RunUnix will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) RunUnix(file string) error { |
||||
lis, err := net.Listen("unix", file) |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
log.Error("failed to listen: %v", err) |
||||
return err |
||||
} |
||||
reflection.Register(s.server) |
||||
return s.Serve(lis) |
||||
} |
||||
|
||||
// Start create a new goroutine run server with configured listen addr
|
||||
// will panic if any error happend
|
||||
// return server itself
|
||||
func (s *Server) Start() (*Server, error) { |
||||
lis, err := net.Listen(s.conf.Network, s.conf.Addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
reflection.Register(s.server) |
||||
go func() { |
||||
if err := s.Serve(lis); err != nil { |
||||
panic(err) |
||||
} |
||||
}() |
||||
return s, nil |
||||
} |
||||
|
||||
// Serve accepts incoming connections on the listener lis, creating a new
|
||||
// ServerTransport and service goroutine for each.
|
||||
// Serve will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) Serve(lis net.Listener) error { |
||||
return s.server.Serve(lis) |
||||
} |
||||
|
||||
// Shutdown stops the server gracefully. It stops the server from
|
||||
// accepting new connections and RPCs and blocks until all the pending RPCs are
|
||||
// finished or the context deadline is reached.
|
||||
func (s *Server) Shutdown(ctx context.Context) (err error) { |
||||
ch := make(chan struct{}) |
||||
go func() { |
||||
s.server.GracefulStop() |
||||
close(ch) |
||||
}() |
||||
select { |
||||
case <-ctx.Done(): |
||||
s.server.Stop() |
||||
err = ctx.Err() |
||||
case <-ch: |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,570 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"math/rand" |
||||
"net" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
"github.com/bilibili/kratos/pkg/log" |
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata" |
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker" |
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto" |
||||
xtrace "github.com/bilibili/kratos/pkg/net/trace" |
||||
xtime "github.com/bilibili/kratos/pkg/time" |
||||
|
||||
"github.com/pkg/errors" |
||||
"github.com/stretchr/testify/assert" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/codes" |
||||
"google.golang.org/grpc/status" |
||||
) |
||||
|
||||
const ( |
||||
_separator = "\001" |
||||
) |
||||
|
||||
var ( |
||||
outPut []string |
||||
_testOnce sync.Once |
||||
server *Server |
||||
|
||||
clientConfig = ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
} |
||||
clientConfig2 = ClientConfig{ |
||||
Dial: xtime.Duration(time.Second * 10), |
||||
Timeout: xtime.Duration(time.Second * 10), |
||||
Breaker: &breaker.Config{ |
||||
Window: xtime.Duration(3 * time.Second), |
||||
Sleep: xtime.Duration(3 * time.Second), |
||||
Bucket: 10, |
||||
Ratio: 0.3, |
||||
Request: 20, |
||||
}, |
||||
Method: map[string]*ClientConfig{`/testproto.Greeter/SayHello`: {Timeout: xtime.Duration(time.Millisecond * 200)}}, |
||||
} |
||||
) |
||||
|
||||
type helloServer struct { |
||||
t *testing.T |
||||
} |
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
if in.Name == "trace_test" { |
||||
t, isok := xtrace.FromContext(ctx) |
||||
if !isok { |
||||
t = xtrace.New("test title") |
||||
s.t.Fatalf("no trace extracted from server context") |
||||
} |
||||
newCtx := xtrace.NewContext(ctx, t) |
||||
if in.Age == 0 { |
||||
runClient(newCtx, &clientConfig, s.t, "trace_test", 1) |
||||
} |
||||
} else if in.Name == "recovery_test" { |
||||
panic("test recovery") |
||||
} else if in.Name == "graceful_shutdown" { |
||||
time.Sleep(time.Second * 3) |
||||
} else if in.Name == "timeout_test" { |
||||
if in.Age > 10 { |
||||
s.t.Fatalf("can not deliver requests over 10 times because of link timeout") |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil |
||||
} |
||||
time.Sleep(time.Millisecond * 10) |
||||
_, err := runClient(ctx, &clientConfig, s.t, "timeout_test", in.Age+1) |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err |
||||
} else if in.Name == "timeout_test2" { |
||||
if in.Age > 10 { |
||||
s.t.Fatalf("can not deliver requests over 10 times because of link timeout") |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil |
||||
} |
||||
time.Sleep(time.Millisecond * 10) |
||||
_, err := runClient(ctx, &clientConfig2, s.t, "timeout_test2", in.Age+1) |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err |
||||
} else if in.Name == "color_test" { |
||||
if in.Age == 0 { |
||||
resp, err := runClient(ctx, &clientConfig, s.t, "color_test", in.Age+1) |
||||
return resp, err |
||||
} |
||||
color := nmd.String(ctx, nmd.Color) |
||||
return &pb.HelloReply{Message: "Hello " + color, Success: true}, nil |
||||
} else if in.Name == "breaker_test" { |
||||
if rand.Intn(100) <= 50 { |
||||
return nil, status.Errorf(codes.ResourceExhausted, "test") |
||||
} |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil |
||||
} else if in.Name == "error_detail" { |
||||
err, _ := ecode.Error(ecode.Code(123456), "test_error_detail").WithDetails(&pb.HelloReply{Success: true}) |
||||
return nil, err |
||||
} else if in.Name == "ecode_status" { |
||||
reply := &pb.HelloReply{Message: "status", Success: true} |
||||
st, _ := ecode.Error(ecode.RequestErr, "RequestErr").WithDetails(reply) |
||||
return nil, st |
||||
} else if in.Name == "general_error" { |
||||
return nil, fmt.Errorf("haha is error") |
||||
} else if in.Name == "ecode_code_error" { |
||||
return nil, ecode.Conflict |
||||
} else if in.Name == "pb_error_error" { |
||||
return nil, ecode.Error(ecode.Code(11122), "haha") |
||||
} else if in.Name == "ecode_status_error" { |
||||
return nil, ecode.Error(ecode.RequestErr, "RequestErr") |
||||
} else if in.Name == "test_remote_port" { |
||||
if strconv.Itoa(int(in.Age)) != nmd.String(ctx, nmd.RemotePort) { |
||||
return nil, fmt.Errorf("error port %d", in.Age) |
||||
} |
||||
reply := &pb.HelloReply{Message: "status", Success: true} |
||||
return reply, nil |
||||
} |
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil |
||||
} |
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error { |
||||
for i := 0; i < 3; i++ { |
||||
in, err := ss.Recv() |
||||
if err == io.EOF { |
||||
return nil |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true} |
||||
err = ss.Send(ret) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func runServer(t *testing.T, interceptors ...grpc.UnaryServerInterceptor) func() { |
||||
return func() { |
||||
server = NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)}) |
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{t}) |
||||
server.Use( |
||||
func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
outPut = append(outPut, "1") |
||||
resp, err := handler(ctx, req) |
||||
outPut = append(outPut, "2") |
||||
return resp, err |
||||
}, |
||||
func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { |
||||
outPut = append(outPut, "3") |
||||
resp, err := handler(ctx, req) |
||||
outPut = append(outPut, "4") |
||||
return resp, err |
||||
}) |
||||
if _, err := server.Start(); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func runClient(ctx context.Context, cc *ClientConfig, t *testing.T, name string, age int32, interceptors ...grpc.UnaryClientInterceptor) (resp *pb.HelloReply, err error) { |
||||
client := NewClient(cc) |
||||
client.Use(interceptors...) |
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080") |
||||
if err != nil { |
||||
panic(fmt.Errorf("did not connect: %v,req: %v %v", err, name, age)) |
||||
} |
||||
defer conn.Close() |
||||
c := pb.NewGreeterClient(conn) |
||||
resp, err = c.SayHello(ctx, &pb.HelloRequest{Name: name, Age: age}) |
||||
return |
||||
} |
||||
|
||||
func TestMain(t *testing.T) { |
||||
log.Init(nil) |
||||
} |
||||
|
||||
func Test_Warden(t *testing.T) { |
||||
xtrace.Init(&xtrace.Config{Addr: "127.0.0.1:9982", Timeout: xtime.Duration(time.Second * 3)}) |
||||
go _testOnce.Do(runServer(t)) |
||||
go runClient(context.Background(), &clientConfig, t, "trace_test", 0) |
||||
testTrace(t, 9982, false) |
||||
testInterceptorChain(t) |
||||
testValidation(t) |
||||
testServerRecovery(t) |
||||
testClientRecovery(t) |
||||
testErrorDetail(t) |
||||
testECodeStatus(t) |
||||
testColorPass(t) |
||||
testRemotePort(t) |
||||
testLinkTimeout(t) |
||||
testClientConfig(t) |
||||
testBreaker(t) |
||||
testAllErrorCase(t) |
||||
testGracefulShutDown(t) |
||||
} |
||||
|
||||
func testValidation(t *testing.T) { |
||||
_, err := runClient(context.Background(), &clientConfig, t, "", 0) |
||||
if !ecode.RequestErr.Equal(err) { |
||||
t.Fatalf("testValidation should return ecode.RequestErr,but is %v", err) |
||||
} |
||||
} |
||||
|
||||
func testAllErrorCase(t *testing.T) { |
||||
// } else if in.Name == "general_error" {
|
||||
// return nil, fmt.Errorf("haha is error")
|
||||
// } else if in.Name == "ecode_code_error" {
|
||||
// return nil, ecode.CreativeArticleTagErr
|
||||
// } else if in.Name == "pb_error_error" {
|
||||
// return nil, &errpb.Error{ErrCode: 11122, ErrMessage: "haha"}
|
||||
// } else if in.Name == "ecode_status_error" {
|
||||
// return nil, ecode.Error(ecode.RequestErr, "RequestErr")
|
||||
// }
|
||||
ctx := context.Background() |
||||
t.Run("general_error", func(t *testing.T) { |
||||
_, err := runClient(ctx, &clientConfig, t, "general_error", 0) |
||||
assert.Contains(t, err.Error(), "haha") |
||||
ec := ecode.Cause(err) |
||||
assert.Equal(t, -500, ec.Code()) |
||||
// remove this assert in future
|
||||
assert.Equal(t, "-500", ec.Message()) |
||||
}) |
||||
t.Run("ecode_code_error", func(t *testing.T) { |
||||
_, err := runClient(ctx, &clientConfig, t, "ecode_code_error", 0) |
||||
ec := ecode.Cause(err) |
||||
assert.Equal(t, ecode.Conflict.Code(), ec.Code()) |
||||
// remove this assert in future
|
||||
assert.Equal(t, "20024", ec.Message()) |
||||
}) |
||||
t.Run("pb_error_error", func(t *testing.T) { |
||||
_, err := runClient(ctx, &clientConfig, t, "pb_error_error", 0) |
||||
ec := ecode.Cause(err) |
||||
assert.Equal(t, 11122, ec.Code()) |
||||
assert.Equal(t, "haha", ec.Message()) |
||||
}) |
||||
t.Run("ecode_status_error", func(t *testing.T) { |
||||
_, err := runClient(ctx, &clientConfig, t, "ecode_status_error", 0) |
||||
ec := ecode.Cause(err) |
||||
assert.Equal(t, ecode.RequestErr.Code(), ec.Code()) |
||||
assert.Equal(t, "RequestErr", ec.Message()) |
||||
}) |
||||
} |
||||
|
||||
func testBreaker(t *testing.T) { |
||||
client := NewClient(&clientConfig) |
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080") |
||||
if err != nil { |
||||
t.Fatalf("did not connect: %v", err) |
||||
} |
||||
defer conn.Close() |
||||
c := pb.NewGreeterClient(conn) |
||||
for i := 0; i < 35; i++ { |
||||
_, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "breaker_test"}) |
||||
if err != nil { |
||||
if ecode.ServiceUnavailable.Equal(err) { |
||||
return |
||||
} |
||||
} |
||||
} |
||||
t.Fatalf("testBreaker failed!No breaker was triggered") |
||||
} |
||||
|
||||
func testColorPass(t *testing.T) { |
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{ |
||||
nmd.Color: "red", |
||||
}) |
||||
resp, err := runClient(ctx, &clientConfig, t, "color_test", 0) |
||||
if err != nil { |
||||
t.Fatalf("testColorPass return error %v", err) |
||||
} |
||||
if resp == nil || resp.Message != "Hello red" { |
||||
t.Fatalf("testColorPass resp.Message must be red,%v", *resp) |
||||
} |
||||
} |
||||
|
||||
func testRemotePort(t *testing.T) { |
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{ |
||||
nmd.RemotePort: "8000", |
||||
}) |
||||
_, err := runClient(ctx, &clientConfig, t, "test_remote_port", 8000) |
||||
if err != nil { |
||||
t.Fatalf("testRemotePort return error %v", err) |
||||
} |
||||
} |
||||
|
||||
func testLinkTimeout(t *testing.T) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) |
||||
defer cancel() |
||||
_, err := runClient(ctx, &clientConfig, t, "timeout_test", 0) |
||||
if err == nil { |
||||
t.Fatalf("testLinkTimeout must return error") |
||||
} |
||||
if !ecode.Deadline.Equal(err) { |
||||
t.Fatalf("testLinkTimeout must return error RPCDeadline,err:%v", err) |
||||
} |
||||
|
||||
} |
||||
func testClientConfig(t *testing.T) { |
||||
_, err := runClient(context.Background(), &clientConfig2, t, "timeout_test2", 0) |
||||
if err == nil { |
||||
t.Fatalf("testLinkTimeout must return error") |
||||
} |
||||
if !ecode.Deadline.Equal(err) { |
||||
t.Fatalf("testLinkTimeout must return error RPCDeadline,err:%v", err) |
||||
} |
||||
} |
||||
|
||||
func testGracefulShutDown(t *testing.T) { |
||||
wg := sync.WaitGroup{} |
||||
for i := 0; i < 10; i++ { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
resp, err := runClient(context.Background(), &clientConfig, t, "graceful_shutdown", 0) |
||||
if err != nil { |
||||
panic(fmt.Errorf("run graceful_shutdown client return(%v)", err)) |
||||
} |
||||
if !resp.Success || resp.Message != "Hello graceful_shutdown" { |
||||
panic(fmt.Errorf("run graceful_shutdown client return(%v,%v)", err, *resp)) |
||||
} |
||||
}() |
||||
} |
||||
go func() { |
||||
time.Sleep(time.Second) |
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) |
||||
defer cancel() |
||||
server.Shutdown(ctx) |
||||
}() |
||||
wg.Wait() |
||||
} |
||||
|
||||
func testClientRecovery(t *testing.T) { |
||||
ctx := context.Background() |
||||
client := NewClient(&clientConfig) |
||||
client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) { |
||||
invoker(ctx, method, req, reply, cc, opts...) |
||||
panic("client recovery test") |
||||
}) |
||||
|
||||
conn, err := client.Dial(ctx, "127.0.0.1:8080") |
||||
if err != nil { |
||||
t.Fatalf("did not connect: %v", err) |
||||
} |
||||
defer conn.Close() |
||||
c := pb.NewGreeterClient(conn) |
||||
|
||||
_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "other_test", Age: 0}) |
||||
if err == nil { |
||||
t.Fatalf("recovery must return error") |
||||
} |
||||
e, ok := errors.Cause(err).(ecode.Codes) |
||||
if !ok { |
||||
t.Fatalf("recovery must return ecode error") |
||||
} |
||||
|
||||
if !ecode.ServerErr.Equal(e) { |
||||
t.Fatalf("recovery must return ecode.RPCClientErr") |
||||
} |
||||
} |
||||
|
||||
func testServerRecovery(t *testing.T) { |
||||
ctx := context.Background() |
||||
client := NewClient(&clientConfig) |
||||
|
||||
conn, err := client.Dial(ctx, "127.0.0.1:8080") |
||||
if err != nil { |
||||
t.Fatalf("did not connect: %v", err) |
||||
} |
||||
defer conn.Close() |
||||
c := pb.NewGreeterClient(conn) |
||||
|
||||
_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "recovery_test", Age: 0}) |
||||
if err == nil { |
||||
t.Fatalf("recovery must return error") |
||||
} |
||||
e, ok := errors.Cause(err).(ecode.Codes) |
||||
if !ok { |
||||
t.Fatalf("recovery must return ecode error") |
||||
} |
||||
|
||||
if e.Code() != ecode.ServerErr.Code() { |
||||
t.Fatalf("recovery must return ecode.ServerErr") |
||||
} |
||||
} |
||||
|
||||
func testInterceptorChain(t *testing.T) { |
||||
// NOTE: don't delete this sleep
|
||||
time.Sleep(time.Millisecond) |
||||
if outPut[0] != "1" || outPut[1] != "3" || outPut[2] != "1" || outPut[3] != "3" || outPut[4] != "4" || outPut[5] != "2" || outPut[6] != "4" || outPut[7] != "2" { |
||||
t.Fatalf("outPut shoud be [1 3 1 3 4 2 4 2]!") |
||||
} |
||||
} |
||||
|
||||
func testErrorDetail(t *testing.T) { |
||||
_, err := runClient(context.Background(), &clientConfig2, t, "error_detail", 0) |
||||
if err == nil { |
||||
t.Fatalf("testErrorDetail must return error") |
||||
} |
||||
if ec, ok := errors.Cause(err).(ecode.Codes); !ok { |
||||
t.Fatalf("testErrorDetail must return ecode error") |
||||
} else if ec.Code() != 123456 || ec.Message() != "test_error_detail" || len(ec.Details()) == 0 { |
||||
t.Fatalf("testErrorDetail must return code:123456 and message:test_error_detail, code: %d, message: %s, details length: %d", ec.Code(), ec.Message(), len(ec.Details())) |
||||
} else if _, ok := ec.Details()[len(ec.Details())-1].(*pb.HelloReply); !ok { |
||||
t.Fatalf("expect get pb.HelloReply %#v", ec.Details()[len(ec.Details())-1]) |
||||
} |
||||
} |
||||
|
||||
func testECodeStatus(t *testing.T) { |
||||
_, err := runClient(context.Background(), &clientConfig2, t, "ecode_status", 0) |
||||
if err == nil { |
||||
t.Fatalf("testECodeStatus must return error") |
||||
} |
||||
st, ok := errors.Cause(err).(*ecode.Status) |
||||
if !ok { |
||||
t.Fatalf("testECodeStatus must return *ecode.Status") |
||||
} |
||||
if st.Code() != int(ecode.RequestErr) && st.Message() != "RequestErr" { |
||||
t.Fatalf("testECodeStatus must return code: -400, message: RequestErr get: code: %d, message: %s", st.Code(), st.Message()) |
||||
} |
||||
detail := st.Details()[0].(*pb.HelloReply) |
||||
if !detail.Success || detail.Message != "status" { |
||||
t.Fatalf("wrong detail") |
||||
} |
||||
} |
||||
|
||||
func testTrace(t *testing.T, port int, isStream bool) { |
||||
listener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port}) |
||||
if err != nil { |
||||
t.Fatalf("listent udp failed, %v", err) |
||||
return |
||||
} |
||||
data := make([]byte, 1024) |
||||
strs := make([][]string, 0) |
||||
for { |
||||
var n int |
||||
n, _, err = listener.ReadFromUDP(data) |
||||
if err != nil { |
||||
t.Fatalf("read from udp faild, %v", err) |
||||
} |
||||
str := strings.Split(string(data[:n]), _separator) |
||||
strs = append(strs, str) |
||||
|
||||
if len(strs) == 2 { |
||||
break |
||||
} |
||||
} |
||||
if len(strs[0]) == 0 || len(strs[1]) == 0 { |
||||
t.Fatalf("trace str's length must be greater than 0") |
||||
} |
||||
} |
||||
|
||||
func BenchmarkServer(b *testing.B) { |
||||
server := NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)}) |
||||
go func() { |
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{}) |
||||
if _, err := server.Start(); err != nil { |
||||
os.Exit(0) |
||||
return |
||||
} |
||||
}() |
||||
defer func() { |
||||
server.Server().Stop() |
||||
}() |
||||
client := NewClient(&clientConfig) |
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080") |
||||
if err != nil { |
||||
conn.Close() |
||||
b.Fatalf("did not connect: %v", err) |
||||
} |
||||
b.ResetTimer() |
||||
b.RunParallel(func(parab *testing.PB) { |
||||
for parab.Next() { |
||||
c := pb.NewGreeterClient(conn) |
||||
resp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "benchmark_test", Age: 1}) |
||||
if err != nil { |
||||
conn.Close() |
||||
b.Fatalf("c.SayHello failed: %v,req: %v %v", err, "benchmark", 1) |
||||
} |
||||
if !resp.Success { |
||||
b.Error("repsonse not success!") |
||||
} |
||||
} |
||||
}) |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestParseDSN(t *testing.T) { |
||||
dsn := "tcp://0.0.0.0:80/?timeout=100ms&idleTimeout=120s&keepaliveInterval=120s&keepaliveTimeout=20s&maxLife=4h&closeWait=3s" |
||||
config := parseDSN(dsn) |
||||
if config.Network != "tcp" || config.Addr != "0.0.0.0:80" || time.Duration(config.Timeout) != time.Millisecond*100 || |
||||
time.Duration(config.IdleTimeout) != time.Second*120 || time.Duration(config.KeepAliveInterval) != time.Second*120 || |
||||
time.Duration(config.MaxLifeTime) != time.Hour*4 || time.Duration(config.ForceCloseWait) != time.Second*3 || time.Duration(config.KeepAliveTimeout) != time.Second*20 { |
||||
t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config) |
||||
} |
||||
|
||||
dsn = "unix:///temp/warden.sock?timeout=300ms" |
||||
config = parseDSN(dsn) |
||||
if config.Network != "unix" || config.Addr != "/temp/warden.sock" || time.Duration(config.Timeout) != time.Millisecond*300 { |
||||
t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config) |
||||
} |
||||
} |
||||
|
||||
type testServer struct { |
||||
helloFn func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) |
||||
} |
||||
|
||||
func (t *testServer) SayHello(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
return t.helloFn(ctx, req) |
||||
} |
||||
|
||||
func (t *testServer) StreamHello(pb.Greeter_StreamHelloServer) error { panic("not implemented") } |
||||
|
||||
// NewTestServerClient .
|
||||
func NewTestServerClient(invoker func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error), svrcfg *ServerConfig, clicfg *ClientConfig) (pb.GreeterClient, func() error) { |
||||
srv := NewServer(svrcfg) |
||||
pb.RegisterGreeterServer(srv.Server(), &testServer{helloFn: invoker}) |
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
ch := make(chan bool, 1) |
||||
go func() { |
||||
ch <- true |
||||
srv.Serve(lis) |
||||
}() |
||||
<-ch |
||||
println(lis.Addr().String()) |
||||
conn, err := NewConn(lis.Addr().String()) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return pb.NewGreeterClient(conn), func() error { return srv.Shutdown(context.Background()) } |
||||
} |
||||
|
||||
func TestMetadata(t *testing.T) { |
||||
cli, cancel := NewTestServerClient(func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) { |
||||
assert.Equal(t, "red", nmd.String(ctx, nmd.Color)) |
||||
assert.Equal(t, "2.2.3.3", nmd.String(ctx, nmd.RemoteIP)) |
||||
assert.Equal(t, "2233", nmd.String(ctx, nmd.RemotePort)) |
||||
return &pb.HelloReply{}, nil |
||||
}, nil, nil) |
||||
defer cancel() |
||||
|
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{ |
||||
nmd.Color: "red", |
||||
nmd.RemoteIP: "2.2.3.3", |
||||
nmd.RemotePort: "2233", |
||||
}) |
||||
_, err := cli.SayHello(ctx, &pb.HelloRequest{Name: "test"}) |
||||
assert.Nil(t, err) |
||||
} |
@ -0,0 +1,25 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
"strconv" |
||||
|
||||
nmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata" |
||||
"github.com/bilibili/kratos/pkg/stat/sys/cpu" |
||||
|
||||
"google.golang.org/grpc" |
||||
gmd "google.golang.org/grpc/metadata" |
||||
) |
||||
|
||||
func (s *Server) stats() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
||||
resp, err = handler(ctx, req) |
||||
var cpustat cpu.Stat |
||||
cpu.ReadStat(&cpustat) |
||||
if cpustat.Usage != 0 { |
||||
trailer := gmd.Pairs([]string{nmd.CPUUsage, strconv.FormatInt(int64(cpustat.Usage), 10)}...) |
||||
grpc.SetTrailer(ctx, trailer) |
||||
} |
||||
return |
||||
} |
||||
} |
@ -0,0 +1,31 @@ |
||||
package warden |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode" |
||||
|
||||
"google.golang.org/grpc" |
||||
"gopkg.in/go-playground/validator.v9" |
||||
) |
||||
|
||||
var validate = validator.New() |
||||
|
||||
// Validate return a client interceptor validate incoming request per RPC call.
|
||||
func (s *Server) validate() grpc.UnaryServerInterceptor { |
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { |
||||
if err = validate.Struct(req); err != nil { |
||||
err = ecode.Error(ecode.RequestErr, err.Error()) |
||||
return |
||||
} |
||||
resp, err = handler(ctx, req) |
||||
return |
||||
} |
||||
} |
||||
|
||||
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||
func (s *Server) RegisterValidation(key string, fn validator.Func) error { |
||||
return validate.RegisterValidation(key, fn) |
||||
} |
Loading…
Reference in new issue