diff --git a/cmd/protoc-gen-go-http/http.go b/cmd/protoc-gen-go-http/http.go index 8d799a3be..f52a20772 100644 --- a/cmd/protoc-gen-go-http/http.go +++ b/cmd/protoc-gen-go-http/http.go @@ -24,7 +24,7 @@ const ( var methodSets = make(map[string]int) // generateFile generates a _http.pb.go file containing kratos errors definitions. -func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool) *protogen.GeneratedFile { +func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool, omitemptyPrefix string) *protogen.GeneratedFile { if len(file.Services) == 0 || (omitempty && !hasHTTPRule(file.Services)) { return nil } @@ -42,12 +42,12 @@ func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool) *pr g.P() g.P("package ", file.GoPackageName) g.P() - generateFileContent(gen, file, g, omitempty) + generateFileContent(gen, file, g, omitempty, omitemptyPrefix) return g } // generateFileContent generates the kratos errors definitions, excluding the package statement. -func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, omitempty bool) { +func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, omitempty bool, omitemptyPrefix string) { if len(file.Services) == 0 { return } @@ -59,11 +59,11 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen. g.P() for _, service := range file.Services { - genService(gen, file, g, service, omitempty) + genService(gen, file, g, service, omitempty, omitemptyPrefix) } } -func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, omitempty bool) { +func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, omitempty bool, omitemptyPrefix string) { if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { g.P("//") g.P(deprecationComment) @@ -81,11 +81,11 @@ func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFi rule, ok := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule) if rule != nil && ok { for _, bind := range rule.AdditionalBindings { - sd.Methods = append(sd.Methods, buildHTTPRule(g, method, bind)) + sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, bind, omitemptyPrefix)) } - sd.Methods = append(sd.Methods, buildHTTPRule(g, method, rule)) + sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, rule, omitemptyPrefix)) } else if !omitempty { - path := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) + path := fmt.Sprintf("%s/%s/%s", omitemptyPrefix, service.Desc.FullName(), method.Desc.Name()) sd.Methods = append(sd.Methods, buildMethodDesc(g, method, http.MethodPost, path)) } } @@ -109,7 +109,7 @@ func hasHTTPRule(services []*protogen.Service) bool { return false } -func buildHTTPRule(g *protogen.GeneratedFile, m *protogen.Method, rule *annotations.HttpRule) *methodDesc { +func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *protogen.Method, rule *annotations.HttpRule, omitemptyPrefix string) *methodDesc { var ( path string method string @@ -137,6 +137,12 @@ func buildHTTPRule(g *protogen.GeneratedFile, m *protogen.Method, rule *annotati path = pattern.Custom.Path method = pattern.Custom.Kind } + if method == "" { + method = http.MethodPost + } + if path == "" { + path = fmt.Sprintf("%s/%s/%s", omitemptyPrefix, service.Desc.FullName(), m.Desc.Name()) + } body = rule.Body responseBody = rule.ResponseBody md := buildMethodDesc(g, m, method, path) diff --git a/cmd/protoc-gen-go-http/main.go b/cmd/protoc-gen-go-http/main.go index a18fbddc4..3752019ca 100644 --- a/cmd/protoc-gen-go-http/main.go +++ b/cmd/protoc-gen-go-http/main.go @@ -9,8 +9,9 @@ import ( ) var ( - showVersion = flag.Bool("version", false, "print the version and exit") - omitempty = flag.Bool("omitempty", true, "omit if google.api is empty") + showVersion = flag.Bool("version", false, "print the version and exit") + omitempty = flag.Bool("omitempty", true, "omit if google.api is empty") + omitemptyPrefix = flag.String("omitempty_prefix", "", "omit if google.api is empty") ) func main() { @@ -27,7 +28,7 @@ func main() { if !f.Generate { continue } - generateFile(gen, f, *omitempty) + generateFile(gen, f, *omitempty, *omitemptyPrefix) } return nil }) diff --git a/encoding/form/form_test.go b/encoding/form/form_test.go index 013edbd1d..27a52c4b1 100644 --- a/encoding/form/form_test.go +++ b/encoding/form/form_test.go @@ -99,7 +99,7 @@ func TestProtoEncodeDecode(t *testing.T) { Price: 11.23, D: 22.22, Byte: []byte("123"), - Map: map[string]string{"kratos": "https://go-kratos.dev/"}, + Map: map[string]string{"kratos": "https://go-kratos.dev/", "kratos_start": "https://go-kratos.dev/en/docs/getting-started/start/"}, Timestamp: ×tamppb.Timestamp{Seconds: 20, Nanos: 2}, Duration: &durationpb.Duration{Seconds: 120, Nanos: 22}, @@ -119,7 +119,8 @@ func TestProtoEncodeDecode(t *testing.T) { t.Fatal(err) } if "a=19&age=18&b=true&bool=false&byte=MTIz&bytes=MTIz&count=3&d=22.22&double=12.33&duration="+ - "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&"+ + "2m0.000000022s&field=1%2C2&float=12.34&id=2233&int32=32&int64=64&"+ + "map%5Bkratos%5D=https%3A%2F%2Fgo-kratos.dev%2F&map%5Bkratos_start%5D=https%3A%2F%2Fgo-kratos.dev%2Fen%2Fdocs%2Fgetting-started%2Fstart%2F&"+ "numberOne=2233&price=11.23&sex=woman&simples=3344&simples=5566&string=go-kratos"+ "×tamp=1970-01-01T00%3A00%3A20.000000002Z&uint32=32&uint64=64&very_simple.component=5566" != string(content) { t.Errorf("rawpath is not equal to %s", content) @@ -153,6 +154,14 @@ func TestProtoEncodeDecode(t *testing.T) { if "5566" != in2.Simples[1] { t.Errorf("expect %v, got %v", "5566", in2.Simples[1]) } + if l := len(in2.GetMap()); l != 2 { + t.Fatalf("in2.Map length want: %d, got: %d", 2, l) + } + for key, val := range in.GetMap() { + if in2Val := in2.GetMap()[key]; in2Val != val { + t.Errorf("%s want: %q, got: %q", "map["+key+"]", val, in2Val) + } + } } func TestDecodeStructPb(t *testing.T) { @@ -181,7 +190,7 @@ func TestDecodeBytesValuePb(t *testing.T) { content := "bytes=" + val in2 := &complex.Complex{} if err := encoding.GetCodec(Name).Unmarshal([]byte(content), in2); err != nil { - t.Error(err) + t.Fatal(err) } if url != string(in2.Bytes.Value) { t.Errorf("except %s, got %s", val, in2.Bytes.Value) diff --git a/encoding/form/proto_decode.go b/encoding/form/proto_decode.go index 1620162b4..28421fefb 100644 --- a/encoding/form/proto_decode.go +++ b/encoding/form/proto_decode.go @@ -20,6 +20,8 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" ) +var ErrInvalidFormatMapKey = errors.New("invalid formatting for map key") + // DecodeValues decode url value into proto message. func DecodeValues(msg proto.Message, values url.Values) error { for key, values := range values { @@ -77,13 +79,23 @@ func populateFieldValues(v protoreflect.Message, fieldPath []string, values []st } func getFieldDescriptor(v protoreflect.Message, fieldName string) protoreflect.FieldDescriptor { - fields := v.Descriptor().Fields() - var fd protoreflect.FieldDescriptor - if fd = getDescriptorByFieldAndName(fields, fieldName); fd == nil { - if v.Descriptor().FullName() == structMessageFullname { + var ( + fields = v.Descriptor().Fields() + fd = getDescriptorByFieldAndName(fields, fieldName) + ) + if fd == nil { + switch { + case v.Descriptor().FullName() == structMessageFullname: fd = fields.ByNumber(structFieldsFieldNumber) - } else if len(fieldName) > 2 && strings.HasSuffix(fieldName, "[]") { + case len(fieldName) > 2 && strings.HasSuffix(fieldName, "[]"): fd = getDescriptorByFieldAndName(fields, strings.TrimSuffix(fieldName, "[]")) + default: + // If the type is map, you get the string "map[kratos]", where "map" is a field of proto and "kratos" is a key of map + field, _, err := parseURLQueryMapKey(fieldName) + if err != nil { + break + } + fd = getDescriptorByFieldAndName(fields, field) } } return fd @@ -121,14 +133,20 @@ func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.Li } func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, fieldPath []string, values []string) error { - // post sub key. - nkey := len(fieldPath) - 1 - key, err := parseField(fd.MapKey(), fieldPath[nkey]) + var ( + nKey = len(fieldPath) - 1 // post sub key + vKey = len(values) - 1 + fieldName = fieldPath[nKey] + ) + _, keyName, err := parseURLQueryMapKey(fieldName) + if err != nil { + return err + } + key, err := parseField(fd.MapKey(), keyName) if err != nil { return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err) } - vkey := len(values) - 1 - value, err := parseField(fd.MapValue(), values[vkey]) + value, err := parseField(fd.MapValue(), values[vKey]) if err != nil { return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err) } @@ -331,3 +349,16 @@ func jsonSnakeCase(s string) string { func isASCIIUpper(c byte) bool { return 'A' <= c && c <= 'Z' } + +// parseURLQueryMapKey parse the url.Values the field name and key name of the value map type key +// for example: convert "map[key]" to "map" and "key" +func parseURLQueryMapKey(key string) (string, string, error) { + var ( + startIndex = strings.IndexByte(key, '[') + endIndex = strings.IndexByte(key, ']') + ) + if startIndex <= 0 || startIndex >= endIndex || len(key) != endIndex+1 { + return "", "", ErrInvalidFormatMapKey + } + return key[:startIndex], key[startIndex+1 : endIndex], nil +} diff --git a/encoding/form/proto_decode_test.go b/encoding/form/proto_decode_test.go index fbe234334..986d7c333 100644 --- a/encoding/form/proto_decode_test.go +++ b/encoding/form/proto_decode_test.go @@ -87,7 +87,7 @@ func TestPopulateMapField(t *testing.T) { comp := &complex.Complex{} field := getFieldDescriptor(comp.ProtoReflect(), "map") // Fill the comp map field with the url query values - err = populateMapField(field, comp.ProtoReflect().Mutable(field).Map(), []string{"kratos"}, query["map[kratos]"]) + err = populateMapField(field, comp.ProtoReflect().Mutable(field).Map(), []string{"map[kratos]"}, query["map[kratos]"]) if err != nil { t.Fatal(err) } @@ -215,3 +215,60 @@ func TestIsASCIIUpper(t *testing.T) { }) } } + +func TestParseURLQueryMapKey(t *testing.T) { + tests := []struct { + fieldName string + field string + fieldKey string + err error + }{ + { + fieldName: "map[kratos]", field: "map", fieldKey: "kratos", err: nil, + }, + { + fieldName: "map[]", field: "map", fieldKey: "", err: nil, + }, + { + fieldName: "", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[[]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map[kratos]=", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[kratos]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "map[", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "]kratos[", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "[kratos", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + { + fieldName: "kratos]", field: "", fieldKey: "", err: ErrInvalidFormatMapKey, + }, + } + for _, test := range tests { + t.Run(test.fieldName, func(t *testing.T) { + fieldName, fieldKey, err := parseURLQueryMapKey(test.fieldName) + if test.err != err { + t.Fatalf("want: %s, got: %s", test.err, err) + } + if test.field != fieldName { + t.Errorf("want: %s, got: %s", test.field, fieldName) + } + if test.fieldKey != fieldKey { + t.Errorf("want: %s, got: %s", test.fieldKey, fieldKey) + } + }) + } +} diff --git a/log/filter.go b/log/filter.go index 993fa926e..7ef96655d 100644 --- a/log/filter.go +++ b/log/filter.go @@ -1,5 +1,7 @@ package log +import "context" + // FilterOption is filter option. type FilterOption func(*Filter) @@ -39,6 +41,7 @@ func FilterFunc(f func(level Level, keyvals ...interface{}) bool) FilterOption { // Filter is a logger filter. type Filter struct { + ctx context.Context logger Logger level Level key map[interface{}]struct{} @@ -67,6 +70,9 @@ func (f *Filter) Log(level Level, keyvals ...interface{}) error { // prefixkv contains the slice of arguments defined as prefixes during the log initialization var prefixkv []interface{} l, ok := f.logger.(*logger) + if ok { + l.ctx = f.ctx + } if ok && len(l.prefix) > 0 { prefixkv = make([]interface{}, 0, len(l.prefix)) prefixkv = append(prefixkv, l.prefix...) diff --git a/log/filter_test.go b/log/filter_test.go index 9258400c3..679fbe8c4 100644 --- a/log/filter_test.go +++ b/log/filter_test.go @@ -2,7 +2,9 @@ package log import ( "bytes" + "context" "io" + "strings" "testing" ) @@ -140,3 +142,33 @@ func testFilterFuncWithLoggerPrefix(level Level, keyvals ...interface{}) bool { } return false } + +func TestFilterWithContext(t *testing.T) { + ctxKey := struct{}{} + ctxValue := "filter test value" + + v1 := func() Valuer { + return func(ctx context.Context) interface{} { + return ctx.Value(ctxKey) + } + } + + info := &bytes.Buffer{} + + logger := With(NewStdLogger(info), "request_id", v1()) + filter := NewFilter(logger, FilterLevel(LevelError)) + + ctx := context.WithValue(context.Background(), ctxKey, ctxValue) + + _ = WithContext(ctx, filter).Log(LevelInfo, "kind", "test") + + if info.String() != "" { + t.Error("filter is not woring") + return + } + + _ = WithContext(ctx, filter).Log(LevelError, "kind", "test") + if !strings.Contains(info.String(), ctxValue) { + t.Error("don't read ctx value") + } +} diff --git a/log/log.go b/log/log.go index 5022ba4bb..38cef8052 100644 --- a/log/log.go +++ b/log/log.go @@ -51,13 +51,27 @@ func With(l Logger, kv ...interface{}) Logger { // to ctx. The provided ctx must be non-nil. func WithContext(ctx context.Context, l Logger) Logger { c, ok := l.(*logger) - if !ok { - return &logger{logger: l, ctx: ctx} + if ok { + return &logger{ + logger: c.logger, + prefix: c.prefix, + hasValuer: c.hasValuer, + ctx: ctx, + } } - return &logger{ - logger: c.logger, - prefix: c.prefix, - hasValuer: c.hasValuer, - ctx: ctx, + + f, ok := l.(*Filter) + if ok { + f.ctx = ctx + return &Filter{ + ctx: ctx, + logger: f.logger, + level: f.level, + key: f.key, + value: f.value, + filter: f.filter, + } } + + return &logger{logger: l, ctx: ctx} }