From b3eff576ce0d5cbadf53ec2a9ab1232d45e5cef3 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Sun, 10 Jul 2022 20:34:19 +0800 Subject: [PATCH] =?UTF-8?q?test(transport):=20add=20unit=20test=20for=20tr?= =?UTF-8?q?ansport=20coverage:=2091.7%=20of=20state=E2=80=A6=20(#2172)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test(transport): add unit test for transport coverage: 91.7% of statements * lint & fix data race * fix test * fix lint * fix lint * remove http func wrapper * remove error log when watcher context is canceled * optimize code style --- internal/testdata/binding/generate.go | 3 + internal/testdata/binding/test.pb.go | 47 ++++--- internal/testdata/binding/test.proto | 1 + internal/testdata/helloworld/generate.go | 3 + internal/testdata/helloworld/helloworld.pb.go | 118 +++++++++--------- internal/testdata/helloworld/helloworld.proto | 2 + .../testdata/helloworld/helloworld_grpc.pb.go | 79 +++++++++++- .../testdata/helloworld/helloworld_http.pb.go | 8 +- .../grpc/resolver/direct/builder_test.go | 15 ++- transport/grpc/resolver/discovery/builder.go | 14 ++- .../grpc/resolver/discovery/builder_test.go | 33 ++++- transport/grpc/resolver/discovery/resolver.go | 10 +- .../grpc/resolver/discovery/resolver_test.go | 12 +- transport/grpc/server_test.go | 57 ++++++++- transport/http/binding/encode.go | 13 +- transport/http/binding/encode_test.go | 28 +++++ transport/http/client_test.go | 63 +++++++++- transport/http/context_test.go | 7 ++ transport/http/resolver_test.go | 74 +++++++++-- transport/http/server.go | 1 + transport/http/server_test.go | 51 ++++++-- transport/transport_test.go | 3 + 22 files changed, 516 insertions(+), 126 deletions(-) create mode 100644 internal/testdata/binding/generate.go create mode 100644 internal/testdata/helloworld/generate.go diff --git a/internal/testdata/binding/generate.go b/internal/testdata/binding/generate.go new file mode 100644 index 000000000..7abe42f8e --- /dev/null +++ b/internal/testdata/binding/generate.go @@ -0,0 +1,3 @@ +package binding + +//go:generate protoc -I . --go_out=paths=source_relative:. ./test.proto diff --git a/internal/testdata/binding/test.pb.go b/internal/testdata/binding/test.pb.go index 3a99e6d9d..a0f909199 100644 --- a/internal/testdata/binding/test.pb.go +++ b/internal/testdata/binding/test.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 -// protoc v3.20.0 +// protoc v3.17.3 // source: test.proto package binding @@ -27,13 +27,14 @@ type HelloRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Sub *Sub `protobuf:"bytes,2,opt,name=sub,proto3" json:"sub,omitempty"` - UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,3,opt,name=update_mask,json=updateMask,proto3" json:"update_mask,omitempty"` - OptInt32 *int32 `protobuf:"varint,4,opt,name=opt_int32,json=optInt32,proto3,oneof" json:"opt_int32,omitempty"` - OptInt64 *int64 `protobuf:"varint,5,opt,name=opt_int64,json=optInt64,proto3,oneof" json:"opt_int64,omitempty"` - OptString *string `protobuf:"bytes,6,opt,name=opt_string,json=optString,proto3,oneof" json:"opt_string,omitempty"` - SubField *Sub `protobuf:"bytes,7,opt,name=subField,proto3" json:"subField,omitempty"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Sub *Sub `protobuf:"bytes,2,opt,name=sub,proto3" json:"sub,omitempty"` + UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,3,opt,name=update_mask,json=updateMask,proto3" json:"update_mask,omitempty"` + OptInt32 *int32 `protobuf:"varint,4,opt,name=opt_int32,json=optInt32,proto3,oneof" json:"opt_int32,omitempty"` + OptInt64 *int64 `protobuf:"varint,5,opt,name=opt_int64,json=optInt64,proto3,oneof" json:"opt_int64,omitempty"` + OptString *string `protobuf:"bytes,6,opt,name=opt_string,json=optString,proto3,oneof" json:"opt_string,omitempty"` + SubField *Sub `protobuf:"bytes,7,opt,name=subField,proto3" json:"subField,omitempty"` + TestRepeated []string `protobuf:"bytes,8,rep,name=test_repeated,proto3" json:"test_repeated,omitempty"` } func (x *HelloRequest) Reset() { @@ -117,6 +118,13 @@ func (x *HelloRequest) GetSubField() *Sub { return nil } +func (x *HelloRequest) GetTestRepeated() []string { + if x != nil { + return x.TestRepeated + } + return nil +} + type Sub struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -170,7 +178,7 @@ var file_test_proto_rawDesc = []byte{ 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x62, 0x69, 0x6e, 0x64, 0x69, 0x6e, 0x67, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6d, 0x61, 0x73, - 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xbc, 0x02, 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, + 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe2, 0x02, 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1e, 0x0a, 0x03, 0x73, 0x75, 0x62, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, 0x2e, 0x62, 0x69, 0x6e, 0x64, @@ -187,15 +195,18 @@ var file_test_proto_rawDesc = []byte{ 0x09, 0x48, 0x02, 0x52, 0x09, 0x6f, 0x70, 0x74, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x88, 0x01, 0x01, 0x12, 0x28, 0x0a, 0x08, 0x73, 0x75, 0x62, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0c, 0x2e, 0x62, 0x69, 0x6e, 0x64, 0x69, 0x6e, 0x67, 0x2e, 0x53, 0x75, - 0x62, 0x52, 0x08, 0x73, 0x75, 0x62, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, - 0x6f, 0x70, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x6f, 0x70, - 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x6f, 0x70, 0x74, 0x5f, - 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x22, 0x1b, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x12, 0x14, 0x0a, - 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6e, 0x61, 0x6d, - 0x69, 0x6e, 0x67, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x6b, 0x72, 0x61, 0x74, - 0x6f, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x62, 0x69, 0x6e, - 0x64, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x62, 0x52, 0x08, 0x73, 0x75, 0x62, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x24, 0x0a, 0x0d, 0x74, + 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x08, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0d, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, + 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x42, + 0x0c, 0x0a, 0x0a, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x42, 0x0d, 0x0a, + 0x0b, 0x5f, 0x6f, 0x70, 0x74, 0x5f, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x22, 0x1b, 0x0a, 0x03, + 0x53, 0x75, 0x62, 0x12, 0x14, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x6b, 0x72, 0x61, 0x74, 0x6f, + 0x73, 0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x72, 0x74, 0x2f, 0x62, 0x69, 0x6e, 0x64, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/internal/testdata/binding/test.proto b/internal/testdata/binding/test.proto index 613bd19d8..f6df42fc4 100644 --- a/internal/testdata/binding/test.proto +++ b/internal/testdata/binding/test.proto @@ -15,6 +15,7 @@ message HelloRequest { optional int64 opt_int64 = 5; optional string opt_string = 6; Sub subField = 7; + repeated string test_repeated = 8 [json_name = "test_repeated"]; } message Sub{ diff --git a/internal/testdata/helloworld/generate.go b/internal/testdata/helloworld/generate.go new file mode 100644 index 000000000..05de6fcb3 --- /dev/null +++ b/internal/testdata/helloworld/generate.go @@ -0,0 +1,3 @@ +package helloworld + +//go:generate protoc -I . -I ../../../third_party --go_out=paths=source_relative:. --go-grpc_out=paths=source_relative:. --go-http_out=paths=source_relative:. ./helloworld.proto diff --git a/internal/testdata/helloworld/helloworld.pb.go b/internal/testdata/helloworld/helloworld.pb.go index 656819d43..48d44179e 100644 --- a/internal/testdata/helloworld/helloworld.pb.go +++ b/internal/testdata/helloworld/helloworld.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.17.3 -// source: helloworld/helloworld.proto +// source: helloworld.proto package helloworld @@ -33,7 +33,7 @@ type HelloRequest struct { func (x *HelloRequest) Reset() { *x = HelloRequest{} if protoimpl.UnsafeEnabled { - mi := &file_helloworld_helloworld_proto_msgTypes[0] + mi := &file_helloworld_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -46,7 +46,7 @@ func (x *HelloRequest) String() string { func (*HelloRequest) ProtoMessage() {} func (x *HelloRequest) ProtoReflect() protoreflect.Message { - mi := &file_helloworld_helloworld_proto_msgTypes[0] + mi := &file_helloworld_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -59,7 +59,7 @@ func (x *HelloRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use HelloRequest.ProtoReflect.Descriptor instead. func (*HelloRequest) Descriptor() ([]byte, []int) { - return file_helloworld_helloworld_proto_rawDescGZIP(), []int{0} + return file_helloworld_proto_rawDescGZIP(), []int{0} } func (x *HelloRequest) GetName() string { @@ -81,7 +81,7 @@ type HelloReply struct { func (x *HelloReply) Reset() { *x = HelloReply{} if protoimpl.UnsafeEnabled { - mi := &file_helloworld_helloworld_proto_msgTypes[1] + mi := &file_helloworld_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -94,7 +94,7 @@ func (x *HelloReply) String() string { func (*HelloReply) ProtoMessage() {} func (x *HelloReply) ProtoReflect() protoreflect.Message { - mi := &file_helloworld_helloworld_proto_msgTypes[1] + mi := &file_helloworld_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -107,7 +107,7 @@ func (x *HelloReply) ProtoReflect() protoreflect.Message { // Deprecated: Use HelloReply.ProtoReflect.Descriptor instead. func (*HelloReply) Descriptor() ([]byte, []int) { - return file_helloworld_helloworld_proto_rawDescGZIP(), []int{1} + return file_helloworld_proto_rawDescGZIP(), []int{1} } func (x *HelloReply) GetMessage() string { @@ -117,65 +117,71 @@ func (x *HelloReply) GetMessage() string { return "" } -var File_helloworld_helloworld_proto protoreflect.FileDescriptor - -var file_helloworld_helloworld_proto_rawDesc = []byte{ - 0x0a, 0x1b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2f, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x68, - 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x22, 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x26, 0x0a, 0x0a, 0x48, - 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x32, 0x63, 0x0a, 0x07, 0x47, 0x72, 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x58, - 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, - 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x1a, 0x82, 0xd3, - 0xe4, 0x93, 0x02, 0x14, 0x12, 0x12, 0x2f, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, - 0x64, 0x2f, 0x7b, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x42, 0x3d, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, - 0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x76, 0x32, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, - 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, 0x68, 0x65, 0x6c, - 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +var File_helloworld_proto protoreflect.FileDescriptor + +var file_helloworld_proto_rawDesc = []byte{ + 0x0a, 0x10, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x0a, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x1a, 0x1c, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x22, 0x0a, 0x0c, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x22, 0x26, 0x0a, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, + 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0xab, 0x01, 0x0a, 0x07, 0x47, 0x72, 0x65, + 0x65, 0x74, 0x65, 0x72, 0x12, 0x58, 0x0a, 0x08, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, + 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, + 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x22, 0x1a, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x14, 0x12, 0x12, 0x2f, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2f, 0x7b, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, 0x46, + 0x0a, 0x0e, 0x53, 0x61, 0x79, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x12, 0x18, 0x2e, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, + 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x68, 0x65, 0x6c, + 0x6c, 0x6f, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x28, 0x01, 0x30, 0x01, 0x42, 0x3d, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x6b, + 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x76, 0x32, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, + 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2f, 0x68, 0x65, 0x6c, 0x6c, 0x6f, + 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( - file_helloworld_helloworld_proto_rawDescOnce sync.Once - file_helloworld_helloworld_proto_rawDescData = file_helloworld_helloworld_proto_rawDesc + file_helloworld_proto_rawDescOnce sync.Once + file_helloworld_proto_rawDescData = file_helloworld_proto_rawDesc ) -func file_helloworld_helloworld_proto_rawDescGZIP() []byte { - file_helloworld_helloworld_proto_rawDescOnce.Do(func() { - file_helloworld_helloworld_proto_rawDescData = protoimpl.X.CompressGZIP(file_helloworld_helloworld_proto_rawDescData) +func file_helloworld_proto_rawDescGZIP() []byte { + file_helloworld_proto_rawDescOnce.Do(func() { + file_helloworld_proto_rawDescData = protoimpl.X.CompressGZIP(file_helloworld_proto_rawDescData) }) - return file_helloworld_helloworld_proto_rawDescData + return file_helloworld_proto_rawDescData } -var file_helloworld_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_helloworld_helloworld_proto_goTypes = []interface{}{ +var file_helloworld_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_helloworld_proto_goTypes = []interface{}{ (*HelloRequest)(nil), // 0: helloworld.HelloRequest (*HelloReply)(nil), // 1: helloworld.HelloReply } -var file_helloworld_helloworld_proto_depIdxs = []int32{ +var file_helloworld_proto_depIdxs = []int32{ 0, // 0: helloworld.Greeter.SayHello:input_type -> helloworld.HelloRequest - 1, // 1: helloworld.Greeter.SayHello:output_type -> helloworld.HelloReply - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 0, // 1: helloworld.Greeter.SayHelloStream:input_type -> helloworld.HelloRequest + 1, // 2: helloworld.Greeter.SayHello:output_type -> helloworld.HelloReply + 1, // 3: helloworld.Greeter.SayHelloStream:output_type -> helloworld.HelloReply + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } -func init() { file_helloworld_helloworld_proto_init() } -func file_helloworld_helloworld_proto_init() { - if File_helloworld_helloworld_proto != nil { +func init() { file_helloworld_proto_init() } +func file_helloworld_proto_init() { + if File_helloworld_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_helloworld_helloworld_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_helloworld_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*HelloRequest); i { case 0: return &v.state @@ -187,7 +193,7 @@ func file_helloworld_helloworld_proto_init() { return nil } } - file_helloworld_helloworld_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_helloworld_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*HelloReply); i { case 0: return &v.state @@ -204,18 +210,18 @@ func file_helloworld_helloworld_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_helloworld_helloworld_proto_rawDesc, + RawDescriptor: file_helloworld_proto_rawDesc, NumEnums: 0, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, - GoTypes: file_helloworld_helloworld_proto_goTypes, - DependencyIndexes: file_helloworld_helloworld_proto_depIdxs, - MessageInfos: file_helloworld_helloworld_proto_msgTypes, + GoTypes: file_helloworld_proto_goTypes, + DependencyIndexes: file_helloworld_proto_depIdxs, + MessageInfos: file_helloworld_proto_msgTypes, }.Build() - File_helloworld_helloworld_proto = out.File - file_helloworld_helloworld_proto_rawDesc = nil - file_helloworld_helloworld_proto_goTypes = nil - file_helloworld_helloworld_proto_depIdxs = nil + File_helloworld_proto = out.File + file_helloworld_proto_rawDesc = nil + file_helloworld_proto_goTypes = nil + file_helloworld_proto_depIdxs = nil } diff --git a/internal/testdata/helloworld/helloworld.proto b/internal/testdata/helloworld/helloworld.proto index bdac62c1f..ed068a38b 100644 --- a/internal/testdata/helloworld/helloworld.proto +++ b/internal/testdata/helloworld/helloworld.proto @@ -14,6 +14,8 @@ service Greeter { get: "/helloworld/{name}", }; } + // Sends a greeting + rpc SayHelloStream (stream HelloRequest) returns (stream HelloReply); } // The request message containing the user's name. diff --git a/internal/testdata/helloworld/helloworld_grpc.pb.go b/internal/testdata/helloworld/helloworld_grpc.pb.go index 4e1c32b3d..3558cbe41 100644 --- a/internal/testdata/helloworld/helloworld_grpc.pb.go +++ b/internal/testdata/helloworld/helloworld_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.17.3 +// source: helloworld.proto package helloworld @@ -20,6 +24,8 @@ const _ = grpc.SupportPackageIsVersion7 type GreeterClient interface { // Sends a greeting SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) + // Sends a greeting + SayHelloStream(ctx context.Context, opts ...grpc.CallOption) (Greeter_SayHelloStreamClient, error) } type greeterClient struct { @@ -39,12 +45,45 @@ func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ... return out, nil } +func (c *greeterClient) SayHelloStream(ctx context.Context, opts ...grpc.CallOption) (Greeter_SayHelloStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &Greeter_ServiceDesc.Streams[0], "/helloworld.Greeter/SayHelloStream", opts...) + if err != nil { + return nil, err + } + x := &greeterSayHelloStreamClient{stream} + return x, nil +} + +type Greeter_SayHelloStreamClient interface { + Send(*HelloRequest) error + Recv() (*HelloReply, error) + grpc.ClientStream +} + +type greeterSayHelloStreamClient struct { + grpc.ClientStream +} + +func (x *greeterSayHelloStreamClient) Send(m *HelloRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *greeterSayHelloStreamClient) Recv() (*HelloReply, error) { + m := new(HelloReply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // GreeterServer is the server API for Greeter service. // All implementations must embed UnimplementedGreeterServer // for forward compatibility type GreeterServer interface { // Sends a greeting SayHello(context.Context, *HelloRequest) (*HelloReply, error) + // Sends a greeting + SayHelloStream(Greeter_SayHelloStreamServer) error mustEmbedUnimplementedGreeterServer() } @@ -55,6 +94,9 @@ type UnimplementedGreeterServer struct { func (UnimplementedGreeterServer) SayHello(context.Context, *HelloRequest) (*HelloReply, error) { return nil, status.Errorf(codes.Unimplemented, "method SayHello not implemented") } +func (UnimplementedGreeterServer) SayHelloStream(Greeter_SayHelloStreamServer) error { + return status.Errorf(codes.Unimplemented, "method SayHelloStream not implemented") +} func (UnimplementedGreeterServer) mustEmbedUnimplementedGreeterServer() {} // UnsafeGreeterServer may be embedded to opt out of forward compatibility for this service. @@ -86,6 +128,32 @@ func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(in return interceptor(ctx, in, info, handler) } +func _Greeter_SayHelloStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(GreeterServer).SayHelloStream(&greeterSayHelloStreamServer{stream}) +} + +type Greeter_SayHelloStreamServer interface { + Send(*HelloReply) error + Recv() (*HelloRequest, error) + grpc.ServerStream +} + +type greeterSayHelloStreamServer struct { + grpc.ServerStream +} + +func (x *greeterSayHelloStreamServer) Send(m *HelloReply) error { + return x.ServerStream.SendMsg(m) +} + +func (x *greeterSayHelloStreamServer) Recv() (*HelloRequest, error) { + m := new(HelloRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // Greeter_ServiceDesc is the grpc.ServiceDesc for Greeter service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -98,6 +166,13 @@ var Greeter_ServiceDesc = grpc.ServiceDesc{ Handler: _Greeter_SayHello_Handler, }, }, - Streams: []grpc.StreamDesc{}, - Metadata: "helloworld/helloworld.proto", + Streams: []grpc.StreamDesc{ + { + StreamName: "SayHelloStream", + Handler: _Greeter_SayHelloStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "helloworld.proto", } diff --git a/internal/testdata/helloworld/helloworld_http.pb.go b/internal/testdata/helloworld/helloworld_http.pb.go index 368b57186..b77f02ec2 100644 --- a/internal/testdata/helloworld/helloworld_http.pb.go +++ b/internal/testdata/helloworld/helloworld_http.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-http. DO NOT EDIT. // versions: -// protoc-gen-go-http v2.1.0 +// protoc-gen-go-http v2.3.1 package helloworld @@ -17,6 +17,8 @@ var _ = binding.EncodeURL const _ = http.SupportPackageIsVersion1 +const OperationGreeterSayHello = "/helloworld.Greeter/SayHello" + type GreeterHTTPServer interface { SayHello(context.Context, *HelloRequest) (*HelloReply, error) } @@ -35,7 +37,7 @@ func _Greeter_SayHello0_HTTP_Handler(srv GreeterHTTPServer) func(ctx http.Contex if err := ctx.BindVars(&in); err != nil { return err } - http.SetOperation(ctx, "/helloworld.Greeter/SayHello") + http.SetOperation(ctx, OperationGreeterSayHello) h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { return srv.SayHello(ctx, req.(*HelloRequest)) }) @@ -64,7 +66,7 @@ func (c *GreeterHTTPClientImpl) SayHello(ctx context.Context, in *HelloRequest, var out HelloReply pattern := "/helloworld/{name}" path := binding.EncodeURL(pattern, in, true) - opts = append(opts, http.Operation("/helloworld.Greeter/SayHello")) + opts = append(opts, http.Operation(OperationGreeterSayHello)) opts = append(opts, http.PathTemplate(pattern)) err := c.cc.Invoke(ctx, "GET", path, nil, &out, opts...) if err != nil { diff --git a/transport/grpc/resolver/direct/builder_test.go b/transport/grpc/resolver/direct/builder_test.go index a0d1d31ab..1cd6bf402 100644 --- a/transport/grpc/resolver/direct/builder_test.go +++ b/transport/grpc/resolver/direct/builder_test.go @@ -1,6 +1,7 @@ package direct import ( + "fmt" "reflect" "testing" @@ -15,9 +16,14 @@ func TestDirectBuilder_Scheme(t *testing.T) { } } -type mockConn struct{} +type mockConn struct { + needUpdateStateErr bool +} func (m *mockConn) UpdateState(resolver.State) error { + if m.needUpdateStateErr { + return fmt.Errorf("mock test needUpdateStateErr") + } return nil } @@ -38,4 +44,11 @@ func TestDirectBuilder_Build(t *testing.T) { t.Errorf("expect no error, got %v", err) } r.ResolveNow(resolver.ResolveNowOptions{}) + r.Close() + + // need update state err + _, err = b.Build(resolver.Target{}, &mockConn{needUpdateStateErr: true}, resolver.BuildOptions{}) + if err == nil { + t.Errorf("expect needUpdateStateErr, got nil") + } } diff --git a/transport/grpc/resolver/discovery/builder.go b/transport/grpc/resolver/discovery/builder.go index 7ce777176..a4425a518 100644 --- a/transport/grpc/resolver/discovery/builder.go +++ b/transport/grpc/resolver/discovery/builder.go @@ -59,18 +59,24 @@ func NewBuilder(d registry.Discovery, opts ...Option) resolver.Builder { } func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) { - var ( + watchRes := &struct { err error w registry.Watcher - ) + }{} + done := make(chan struct{}, 1) ctx, cancel := context.WithCancel(context.Background()) go func() { - w, err = b.discoverer.Watch(ctx, strings.TrimPrefix(target.URL.Path, "/")) + w, err := b.discoverer.Watch(ctx, strings.TrimPrefix(target.URL.Path, "/")) + watchRes.w = w + watchRes.err = err close(done) }() + + var err error select { case <-done: + err = watchRes.err case <-time.After(b.timeout): err = errors.New("discovery create watcher overtime") } @@ -79,7 +85,7 @@ func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts res return nil, err } r := &discoveryResolver{ - w: w, + w: watchRes.w, cc: cc, ctx: ctx, cancel: cancel, diff --git a/transport/grpc/resolver/discovery/builder_test.go b/transport/grpc/resolver/discovery/builder_test.go index 7a06abfcd..11541cb76 100644 --- a/transport/grpc/resolver/discovery/builder_test.go +++ b/transport/grpc/resolver/discovery/builder_test.go @@ -28,6 +28,14 @@ func TestWithTimeout(t *testing.T) { } } +func TestDisableDebugLog(t *testing.T) { + o := &builder{} + DisableDebugLog()(o) + if !o.debugLogDisabled { + t.Errorf("expected debugLogDisabled true, got %v", o.debugLogDisabled) + } +} + type mockDiscovery struct{} func (m *mockDiscovery) GetService(ctx context.Context, serviceName string) ([]*registry.ServiceInstance, error) { @@ -35,6 +43,7 @@ func (m *mockDiscovery) GetService(ctx context.Context, serviceName string) ([]* } func (m *mockDiscovery) Watch(ctx context.Context, serviceName string) (registry.Watcher, error) { + time.Sleep(time.Microsecond * 500) return &testWatch{}, nil } @@ -62,9 +71,29 @@ func (m *mockConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.P } func TestBuilder_Build(t *testing.T) { - b := NewBuilder(&mockDiscovery{}) - _, err := b.Build(resolver.Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "gprc://authority/endpoint"}, &mockConn{}, resolver.BuildOptions{}) + b := NewBuilder(&mockDiscovery{}, DisableDebugLog()) + _, err := b.Build( + resolver.Target{ + Scheme: resolver.GetDefaultScheme(), + Endpoint: "gprc://authority/endpoint", + }, + &mockConn{}, + resolver.BuildOptions{}, + ) if err != nil { t.Errorf("expected no error, got %v", err) + return + } + timeoutBuilder := NewBuilder(&mockDiscovery{}, WithTimeout(0)) + _, err = timeoutBuilder.Build( + resolver.Target{ + Scheme: resolver.GetDefaultScheme(), + Endpoint: "gprc://authority/endpoint", + }, + &mockConn{}, + resolver.BuildOptions{}, + ) + if err == nil { + t.Errorf("expected error, got %v", err) } } diff --git a/transport/grpc/resolver/discovery/resolver.go b/transport/grpc/resolver/discovery/resolver.go index 6a7f6e7e8..0dc115b65 100644 --- a/transport/grpc/resolver/discovery/resolver.go +++ b/transport/grpc/resolver/discovery/resolver.go @@ -49,23 +49,23 @@ func (r *discoveryResolver) update(ins []*registry.ServiceInstance) { addrs := make([]resolver.Address, 0) endpoints := make(map[string]struct{}) for _, in := range ins { - endpoint, err := endpoint.ParseEndpoint(in.Endpoints, endpoint.Scheme("grpc", !r.insecure)) + ept, err := endpoint.ParseEndpoint(in.Endpoints, endpoint.Scheme("grpc", !r.insecure)) if err != nil { log.Errorf("[resolver] Failed to parse discovery endpoint: %v", err) continue } - if endpoint == "" { + if ept == "" { continue } // filter redundant endpoints - if _, ok := endpoints[endpoint]; ok { + if _, ok := endpoints[ept]; ok { continue } - endpoints[endpoint] = struct{}{} + endpoints[ept] = struct{}{} addr := resolver.Address{ ServerName: in.Name, Attributes: parseAttributes(in.Metadata), - Addr: endpoint, + Addr: ept, } addr.Attributes = addr.Attributes.WithValue("rawServiceInstance", in) addrs = append(addrs, addr) diff --git a/transport/grpc/resolver/discovery/resolver_test.go b/transport/grpc/resolver/discovery/resolver_test.go index 1d196382b..33d9fc27e 100644 --- a/transport/grpc/resolver/discovery/resolver_test.go +++ b/transport/grpc/resolver/discovery/resolver_test.go @@ -23,10 +23,16 @@ func (t *testClientConn) UpdateState(s resolver.State) error { type testWatch struct { err error + + count uint } func (m *testWatch) Next() ([]*registry.ServiceInstance, error) { time.Sleep(time.Millisecond * 200) + if m.count > 1 { + return nil, nil + } + m.count++ ins := []*registry.ServiceInstance{ { ID: "mock_ID", @@ -59,6 +65,7 @@ func TestWatch(t *testing.T) { cancel: cancel, insecure: false, } + r.ResolveNow(resolver.ResolveNowOptions{}) go func() { time.Sleep(time.Second * 2) r.Close() @@ -102,7 +109,10 @@ func TestWatchContextCancel(t *testing.T) { } func TestParseAttributes(t *testing.T) { - a := parseAttributes(map[string]string{"a": "b"}) + a := parseAttributes(map[string]string{ + "a": "b", + "c": "d", + }) if !reflect.DeepEqual("b", a.Value("a").(string)) { t.Errorf("expect b, got %v", a.Value("a")) } diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index 5dd779f37..3a8141c24 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -15,7 +15,6 @@ import ( pb "github.com/go-kratos/kratos/v2/internal/testdata/helloworld" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" - "google.golang.org/grpc" ) @@ -24,6 +23,36 @@ type server struct { pb.UnimplementedGreeterServer } +func (s *server) SayHelloStream(streamServer pb.Greeter_SayHelloStreamServer) error { + tctx, ok := transport.FromServerContext(streamServer.Context()) + if ok { + tctx.ReplyHeader().Set("123", "123") + } + var cnt uint + for { + in, err := streamServer.Recv() + if err != nil { + return err + } + if in.Name == "error" { + return errors.BadRequest("custom_error", fmt.Sprintf("invalid argument %s", in.Name)) + } + if in.Name == "panic" { + panic("server panic") + } + err = streamServer.Send(&pb.HelloReply{ + Message: fmt.Sprintf("hello %s", in.Name), + }) + if err != nil { + return err + } + cnt++ + if cnt > 1 { + return nil + } + } +} + // SayHello implements helloworld.GreeterServer func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { if in.Name == "error" { @@ -97,6 +126,9 @@ func testClient(t *testing.T, srv *Server) { } }), ) + defer func() { + _ = conn.Close() + }() if err != nil { t.Fatal(err) } @@ -109,7 +141,28 @@ func testClient(t *testing.T, srv *Server) { if !reflect.DeepEqual(reply.Message, "Hello kratos") { t.Errorf("expect %s, got %s", "Hello kratos", reply.Message) } - _ = conn.Close() + + streamCli, err := client.SayHelloStream(context.Background()) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = streamCli.CloseSend() + }() + err = streamCli.Send(&pb.HelloRequest{Name: "cc"}) + if err != nil { + t.Error(err) + return + } + reply, err = streamCli.Recv() + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(reply.Message, "hello cc") { + t.Errorf("expect %s, got %s", "hello cc", reply.Message) + } } func TestNetwork(t *testing.T) { diff --git a/transport/http/binding/encode.go b/transport/http/binding/encode.go index 14faf2e92..3129bbe5d 100644 --- a/transport/http/binding/encode.go +++ b/transport/http/binding/encode.go @@ -12,20 +12,19 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" ) +var reg = regexp.MustCompile(`/{[\\.\w]+}`) + // EncodeURL encode proto message to url path. func EncodeURL(pathTemplate string, msg proto.Message, needQuery bool) string { if msg == nil || (reflect.ValueOf(msg).Kind() == reflect.Ptr && reflect.ValueOf(msg).IsNil()) { return pathTemplate } - reg := regexp.MustCompile(`/{[.\w]+}`) - if reg == nil { - return pathTemplate - } pathParams := make(map[string]struct{}) path := reg.ReplaceAllStringFunc(pathTemplate, func(in string) string { - if len(in) < 4 { //nolint:gomnd // ** explain the 4 number here :-) ** - return in - } + // it's unreachable because the reg means that must have more than one char in {} + //if len(in) < 4 { //nolint:gomnd // ** explain the 4 number here :-) ** + // return in + //} key := in[2 : len(in)-1] vars := strings.Split(key, ".") value, err := getValueByField(msg.ProtoReflect(), vars) diff --git a/transport/http/binding/encode_test.go b/transport/http/binding/encode_test.go index 9898aba1d..fdc2bb78a 100644 --- a/transport/http/binding/encode_test.go +++ b/transport/http/binding/encode_test.go @@ -14,6 +14,34 @@ func TestProtoPath(t *testing.T) { if url != `http://helloworld.Greeter/helloworld/test/sub/2233!!!` { t.Fatalf("proto path not expected!actual: %s ", url) } + url = EncodeURL("http://helloworld.Greeter/helloworld/{name}/sub/{sub.name}", nil, false) + fmt.Println(url) + if url != "http://helloworld.Greeter/helloworld/{name}/sub/{sub.name}" { + t.Fatalf("proto path not expected!actual: %s ", url) + } + url = EncodeURL("http://helloworld.Greeter/helloworld/{}/sub/{sub.name}", &binding.HelloRequest{Name: "test", Sub: &binding.Sub{Name: "hello"}}, false) + fmt.Println(url) + if url != "http://helloworld.Greeter/helloworld/{}/sub/hello" { + t.Fatalf("proto path not expected!actual: %s ", url) + } + url = EncodeURL("http://helloworld.Greeter/helloworld/{}/sub/{sub.name.cc}", &binding.HelloRequest{Name: "test", Sub: &binding.Sub{Name: "hello"}}, false) + fmt.Println(url) + if url != "http://helloworld.Greeter/helloworld/{}/sub/{sub.name.cc}" { + t.Fatalf("proto path not expected!actual: %s ", url) + } + + url = EncodeURL( + "http://helloworld.Greeter/helloworld/{}/sub/{test_repeated.1}", + &binding.HelloRequest{ + Name: "test", Sub: &binding.Sub{Name: "hello"}, + TestRepeated: []string{"123", "456"}, + }, + false, + ) + fmt.Println(url) + if url != "http://helloworld.Greeter/helloworld/{}/sub/{test_repeated.1}" { + t.Fatalf("proto path not expected!actual: %s ", url) + } url = EncodeURL("http://helloworld.Greeter/helloworld/{name}/sub/{sub.naming}", &binding.HelloRequest{Name: "test", Sub: &binding.Sub{Name: "5566!!!"}}, false) fmt.Println(url) diff --git a/transport/http/client_test.go b/transport/http/client_test.go index 731d111dc..c768b364b 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "log" nethttp "net/http" "reflect" "strconv" @@ -17,6 +18,7 @@ import ( kratosErrors "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" + "github.com/go-kratos/kratos/v2/selector" ) type mockRoundTripper struct{} @@ -25,6 +27,21 @@ func (rt *mockRoundTripper) RoundTrip(req *nethttp.Request) (resp *nethttp.Respo return } +type mockCallOption struct { + needErr bool +} + +func (x *mockCallOption) before(info *callInfo) error { + if x.needErr { + return fmt.Errorf("option need return err") + } + return nil +} + +func (x *mockCallOption) after(info *callInfo, attempt *csAttempt) { + log.Println("run in mockCallOption.after") +} + func TestWithTransport(t *testing.T) { ov := &mockRoundTripper{} o := WithTransport(ov) @@ -165,6 +182,16 @@ func TestWithDiscovery(t *testing.T) { } } +func TestWithSelector(t *testing.T) { + ov := &selector.Default{} + o := WithSelector(ov) + co := &clientOptions{} + o(co) + if !reflect.DeepEqual(co.selector, ov) { + t.Errorf("expected selector to be %v, got %v", ov, co.selector) + } +} + func TestDefaultRequestEncoder(t *testing.T) { req1 := &nethttp.Request{ Header: make(nethttp.Header), @@ -284,10 +311,6 @@ func TestNewClient(t *testing.T) { if err != nil { t.Error(err) } - client, err := NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("discovery:///go-kratos")) - if err != nil { - t.Error(err) - } _, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("discovery:///go-kratos")) if err != nil { t.Error(err) @@ -296,13 +319,43 @@ func TestNewClient(t *testing.T) { if err != nil { t.Error(err) } + _, err = NewClient(context.Background(), WithEndpoint("127.0.0.1:8888:xxxxa")) + if err == nil { + t.Error("except a parseTarget error") + } _, err = NewClient(context.Background(), WithDiscovery(&mockDiscovery{}), WithEndpoint("https://go-kratos.dev/")) if err == nil { t.Error("err should not be equal to nil") } - err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}) + client, err := NewClient( + context.Background(), + WithDiscovery(&mockDiscovery{}), + WithEndpoint("discovery:///go-kratos"), + WithMiddleware(func(handler middleware.Handler) middleware.Handler { + t.Logf("handle in middleware") + return func(ctx context.Context, req interface{}) (interface{}, error) { + return handler(ctx, req) + } + }), + ) + if err != nil { + t.Error(err) + } + + err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{}) if err == nil { t.Error("err should not be equal to nil") } + err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}, &mockCallOption{needErr: true}) + if err == nil { + t.Error("err should be equal to callOption err") + } + client.opts.encoder = func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) { + return nil, fmt.Errorf("mock test encoder error") + } + err = client.Invoke(context.Background(), "POST", "/go", map[string]string{"name": "kratos"}, nil, EmptyCallOption{}) + if err == nil { + t.Error("err should be equal to encoder error") + } } diff --git a/transport/http/context_test.go b/transport/http/context_test.go index 85a7aaaa0..3b7466bf0 100644 --- a/transport/http/context_test.go +++ b/transport/http/context_test.go @@ -3,6 +3,8 @@ package http import ( "bytes" "context" + "errors" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -90,6 +92,11 @@ func TestContextResponse(t *testing.T) { if err != nil { t.Errorf("expected %v, got %v", nil, err) } + needErr := fmt.Errorf("some error") + err = w.Returns(map[string]string{}, needErr) + if !errors.Is(err, needErr) { + t.Errorf("expected %v, got %v", needErr, err) + } } func TestContextBindQuery(t *testing.T) { diff --git a/transport/http/resolver_test.go b/transport/http/resolver_test.go index 419bae44a..27473d12b 100644 --- a/transport/http/resolver_test.go +++ b/transport/http/resolver_test.go @@ -61,22 +61,42 @@ func (m *mockRebalancer) Apply(nodes []selector.Node) {} type mockDiscoveries struct { isSecure bool + nextErr bool + stopErr bool } func (d *mockDiscoveries) GetService(ctx context.Context, serviceName string) ([]*registry.ServiceInstance, error) { return nil, nil } +const errServiceName = "needErr" + func (d *mockDiscoveries) Watch(ctx context.Context, serviceName string) (registry.Watcher, error) { - return &mockWatch{isSecure: d.isSecure}, nil + if serviceName == errServiceName { + return nil, fmt.Errorf("mock test service name watch err") + } + return &mockWatch{ctx: ctx, isSecure: d.isSecure, nextErr: d.nextErr, stopErr: d.stopErr}, nil } type mockWatch struct { + ctx context.Context + isSecure bool count int + + nextErr bool + stopErr bool } func (m *mockWatch) Next() ([]*registry.ServiceInstance, error) { + select { + case <-m.ctx.Done(): + return nil, m.ctx.Err() + default: + } + if m.nextErr { + return nil, errors.New("mock test error") + } if m.count == 1 { return nil, errors.New("mock test error") } @@ -95,21 +115,61 @@ func (m *mockWatch) Next() ([]*registry.ServiceInstance, error) { } func (m *mockWatch) Stop() error { + if m.stopErr { + return fmt.Errorf("mock test error") + } + // 标记 next 需要报错 + m.nextErr = true return nil } func TestResolver(t *testing.T) { - ta := &Target{ - Scheme: "http", - Authority: "", - Endpoint: "discovery://helloworld", + ta, err := parseTarget("discovery://helloworld", true) + if err != nil { + t.Errorf("parse err %v", err) + return + } + + // 异步 无需报错 + _, err = newResolver(context.Background(), &mockDiscoveries{true, false, false}, ta, &mockRebalancer{}, false, false) + if err != nil { + t.Errorf("expect %v, got %v", nil, err) } - _, err := newResolver(context.Background(), &mockDiscoveries{true}, ta, &mockRebalancer{}, false, false) + + // 同步 一切正常运行 + _, err = newResolver(context.Background(), &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, true, true) if err != nil { t.Errorf("expect %v, got %v", nil, err) } - _, err = newResolver(context.Background(), &mockDiscoveries{false}, ta, &mockRebalancer{}, true, true) + + // 同步 但是 next 出错 以及 stop 出错 + _, err = newResolver(context.Background(), &mockDiscoveries{false, true, true}, ta, &mockRebalancer{}, true, true) + if err == nil { + t.Errorf("expect err, got nil") + } + + // 同步 service name watch 失败 + _, err = newResolver(context.Background(), &mockDiscoveries{false, true, true}, &Target{ + Scheme: "discovery", + Endpoint: errServiceName, + }, &mockRebalancer{}, true, true) + if err == nil { + t.Errorf("expect err, got nil") + } + + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + // 此处应该打印出来 context.Canceled + r, err := newResolver(cancelCtx, &mockDiscoveries{false, false, false}, ta, &mockRebalancer{}, false, false) if err != nil { t.Errorf("expect %v, got %v", nil, err) } + _ = r.Close() + + // 同步 但是服务取消,此时需要报错 + _, err = newResolver(cancelCtx, &mockDiscoveries{false, false, true}, ta, &mockRebalancer{}, true, true) + if err == nil { + t.Errorf("expect ctx cancel err, got nil") + } } diff --git a/transport/http/server.go b/transport/http/server.go index 64c42df11..a0d4f14f3 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -22,6 +22,7 @@ import ( var ( _ transport.Server = (*Server)(nil) _ transport.Endpointer = (*Server)(nil) + _ http.Handler = (*Server)(nil) ) // ServerOption is an HTTP server option. diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 34f2b8553..ea60d0cf4 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -25,14 +25,28 @@ type testData struct { Path string `json:"path"` } +// handleFuncWrapper is a wrapper for http.HandlerFunc to implement http.Handler +type handleFuncWrapper struct { + fn http.HandlerFunc +} + +func (x *handleFuncWrapper) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + x.fn.ServeHTTP(writer, request) +} + +func newHandleFuncWrapper(fn http.HandlerFunc) http.Handler { + return &handleFuncWrapper{fn: fn} +} + func TestServer(t *testing.T) { fn := func(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI}) } ctx := context.Background() srv := NewServer() - srv.HandleFunc("/index", fn) + srv.Handle("/index", newHandleFuncWrapper(fn)) srv.HandleFunc("/index/{id:[0-9]+}", fn) + srv.HandlePrefix("/test/prefix", newHandleFuncWrapper(fn)) srv.HandleHeader("content-type", "application/grpc-web+json", func(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(testData{Path: r.RequestURI}) }) @@ -55,6 +69,7 @@ func TestServer(t *testing.T) { testHeader(t, srv) testClient(t, srv) testAccept(t, srv) + time.Sleep(time.Second) if srv.Stop(ctx) != nil { t.Errorf("expected nil got %v", srv.Stop(ctx)) } @@ -121,20 +136,21 @@ func testClient(t *testing.T, srv *Server) { path string code int }{ - {"GET", "/index", 200}, - {"PUT", "/index", 200}, - {"POST", "/index", 200}, - {"PATCH", "/index", 200}, - {"DELETE", "/index", 200}, + {"GET", "/index", http.StatusOK}, + {"PUT", "/index", http.StatusOK}, + {"POST", "/index", http.StatusOK}, + {"PATCH", "/index", http.StatusOK}, + {"DELETE", "/index", http.StatusOK}, - {"GET", "/index/1", 200}, - {"PUT", "/index/1", 200}, - {"POST", "/index/1", 200}, - {"PATCH", "/index/1", 200}, - {"DELETE", "/index/1", 200}, + {"GET", "/index/1", http.StatusOK}, + {"PUT", "/index/1", http.StatusOK}, + {"POST", "/index/1", http.StatusOK}, + {"PATCH", "/index/1", http.StatusOK}, + {"DELETE", "/index/1", http.StatusOK}, - {"GET", "/index/notfound", 404}, - {"GET", "/errors/cause", 400}, + {"GET", "/index/notfound", http.StatusNotFound}, + {"GET", "/errors/cause", http.StatusBadRequest}, + {"GET", "/test/prefix/123111", http.StatusOK}, } e, err := srv.Endpoint() if err != nil { @@ -307,6 +323,15 @@ func TestTLSConfig(t *testing.T) { } } +func TestStrictSlash(t *testing.T) { + o := &Server{} + v := true + StrictSlash(v)(o) + if !reflect.DeepEqual(v, o.strictSlash) { + t.Errorf("expected %v got %v", v, o.tlsConf) + } +} + func TestListener(t *testing.T) { lis := &net.TCPListener{} s := &Server{} diff --git a/transport/transport_test.go b/transport/transport_test.go index c29c51e4b..b83c8ee3c 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -55,6 +55,9 @@ func TestServerTransport(t *testing.T) { if mtr == nil { t.Errorf("expected:%v got:%v", nil, mtr) } + if mtr.Kind().String() != KindGRPC.String() { + t.Errorf("expected:%v got:%v", KindGRPC.String(), mtr.Kind().String()) + } if !reflect.DeepEqual(mtr.endpoint, "test_endpoint") { t.Errorf("expected:%v got:%v", "test_endpoint", mtr.endpoint) }