diff --git a/cmd/protoc-gen-go-http/go.mod b/cmd/protoc-gen-go-http/go.mod index fe4126408..d38285e8a 100644 --- a/cmd/protoc-gen-go-http/go.mod +++ b/cmd/protoc-gen-go-http/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/go-kratos/kratos/v2 v2.1.3 + github.com/stretchr/testify v1.7.0 google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67 google.golang.org/protobuf v1.27.1 ) diff --git a/cmd/protoc-gen-go-http/http.go b/cmd/protoc-gen-go-http/http.go index 4d6c80d92..fef91d7a2 100644 --- a/cmd/protoc-gen-go-http/http.go +++ b/cmd/protoc-gen-go-http/http.go @@ -3,6 +3,7 @@ package main import ( "fmt" "os" + "regexp" "strings" "google.golang.org/protobuf/reflect/protoreflect" @@ -157,9 +158,14 @@ func buildHTTPRule(g *protogen.GeneratedFile, m *protogen.Method, rule *annotati func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path string) *methodDesc { defer func() { methodSets[m.GoName]++ }() - vars := buildPathVars(m, path) + + vars := buildPathVars(path) fields := m.Input.Desc.Fields() - for _, v := range vars { + + for v, s := range vars { + if s != nil { + path = replacePath(v, *s, path) + } for _, field := range strings.Split(v, ".") { if strings.TrimSpace(field) == "" { continue @@ -192,16 +198,33 @@ func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path } } -func buildPathVars(method *protogen.Method, path string) (res []string) { - for _, v := range strings.Split(path, "/") { - if strings.HasPrefix(v, "{") && strings.HasSuffix(v, "}") { - name := strings.TrimRight(strings.TrimLeft(v, "{"), "}") - res = append(res, name) +func buildPathVars(path string) (res map[string]*string) { + res = make(map[string]*string) + pattern := regexp.MustCompile(`(?i){([a-z\.0-9_\s]*)=?([^{}]*)}`) + matches := pattern.FindAllStringSubmatch(path, -1) + for _, m := range matches { + name := strings.TrimSpace(m[1]) + if len(name) > 1 && len(m[2]) > 0 { + res[name] = &m[2] + } else { + res[name] = nil } } return } +func replacePath(name string, value string, path string) string { + pattern := regexp.MustCompile(fmt.Sprintf(`(?i){([\s]*%s[\s]*)=`, name)) + idx := pattern.FindStringIndex(path) + if len(idx) > 0 { + path = fmt.Sprintf("%s{%s:%s}", + path[:idx[0]], // The start of the match + name, + strings.ReplaceAll(value, "*", ".*")) + } + return path +} + func camelCaseVars(s string) string { vars := make([]string, 0) subs := strings.Split(s, ".") diff --git a/cmd/protoc-gen-go-http/http_test.go b/cmd/protoc-gen-go-http/http_test.go new file mode 100644 index 000000000..f28a7f410 --- /dev/null +++ b/cmd/protoc-gen-go-http/http_test.go @@ -0,0 +1,54 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNoParameters(t *testing.T) { + path := "/test/noparams" + m := buildPathVars(path) + assert.Emptyf(t, m, "Map should be empty") +} + +func TestSingleParam(t *testing.T) { + path := "/test/{message.id}" + m := buildPathVars(path) + assert.Len(t, m, 1) + assert.Empty(t, m["message.id"]) +} + +func TestTwoParametersReplacement(t *testing.T) { + path := "/test/{message.id}/{message.name=messages/*}" + m := buildPathVars(path) + assert.Len(t, m, 2) + assert.Empty(t, m["message.id"]) + assert.NotEmpty(t, m["message.name"]) + assert.Equal(t, *m["message.name"], "messages/*") +} + +func TestNoReplacePath(t *testing.T) { + path := "/test/{message.id}" + assert.Equal(t, path, replacePath("message.id", "", path)) + + path = "/test/{message.id=test}" + assert.Equal(t, "/test/{message.id:test}", replacePath("message.id", "test", path)) +} + +func TestReplacePath(t *testing.T) { + path := "/test/{message.id}/{message.name=messages/*}" + newPath := replacePath("message.name", "messages/*", path) + assert.Equal(t, "/test/{message.id}/{message.name:messages/.*}", newPath) +} + +func TestIteration(t *testing.T) { + path := "/test/{message.id}/{message.name=messages/*}" + vars := buildPathVars(path) + for v, s := range vars { + if s != nil { + path = replacePath(v, *s, path) + } + } + assert.Equal(t, "/test/{message.id}/{message.name:messages/.*}", path) +}