feat(middleware): add selector matcher (#2239)
* feat(middleware): add selector matcher Co-authored-by: chenzhihui <chenzhihui@bilibili.com>pull/2240/head
parent
377356d04d
commit
f3b0da3f04
@ -0,0 +1,62 @@ |
|||||||
|
package matcher |
||||||
|
|
||||||
|
import ( |
||||||
|
"sort" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware" |
||||||
|
) |
||||||
|
|
||||||
|
// Matcher is a middleware matcher.
|
||||||
|
type Matcher interface { |
||||||
|
Use(ms ...middleware.Middleware) |
||||||
|
Add(selector string, ms ...middleware.Middleware) |
||||||
|
Match(operation string) []middleware.Middleware |
||||||
|
} |
||||||
|
|
||||||
|
// New new a middleware matcher.
|
||||||
|
func New() Matcher { |
||||||
|
return &matcher{ |
||||||
|
matchs: make(map[string][]middleware.Middleware), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type matcher struct { |
||||||
|
prefix []string |
||||||
|
defaults []middleware.Middleware |
||||||
|
matchs map[string][]middleware.Middleware |
||||||
|
} |
||||||
|
|
||||||
|
func (m *matcher) Use(ms ...middleware.Middleware) { |
||||||
|
m.defaults = ms |
||||||
|
} |
||||||
|
|
||||||
|
func (m *matcher) Add(selector string, ms ...middleware.Middleware) { |
||||||
|
if strings.HasSuffix(selector, "*") { |
||||||
|
selector = strings.TrimSuffix(selector, "*") |
||||||
|
m.prefix = append(m.prefix, selector) |
||||||
|
// sort the prefix:
|
||||||
|
// - /foo/bar
|
||||||
|
// - /foo
|
||||||
|
sort.Slice(m.prefix, func(i, j int) bool { |
||||||
|
return m.prefix[i] > m.prefix[j] |
||||||
|
}) |
||||||
|
} |
||||||
|
m.matchs[selector] = ms |
||||||
|
} |
||||||
|
|
||||||
|
func (m *matcher) Match(operation string) []middleware.Middleware { |
||||||
|
ms := make([]middleware.Middleware, 0, len(m.defaults)) |
||||||
|
if len(m.defaults) > 0 { |
||||||
|
ms = append(ms, m.defaults...) |
||||||
|
} |
||||||
|
if next, ok := m.matchs[operation]; ok { |
||||||
|
return append(ms, next...) |
||||||
|
} |
||||||
|
for _, prefix := range m.prefix { |
||||||
|
if strings.HasPrefix(operation, prefix) { |
||||||
|
return append(ms, m.matchs[prefix]...) |
||||||
|
} |
||||||
|
} |
||||||
|
return ms |
||||||
|
} |
@ -0,0 +1,62 @@ |
|||||||
|
package matcher |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware" |
||||||
|
) |
||||||
|
|
||||||
|
func logging(module string) middleware.Middleware { |
||||||
|
return func(handler middleware.Handler) middleware.Handler { |
||||||
|
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { |
||||||
|
return module, nil |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func equal(ms []middleware.Middleware, modules ...string) bool { |
||||||
|
if len(ms) == 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
for i, m := range ms { |
||||||
|
x, _ := m(nil)(nil, nil) |
||||||
|
if x != modules[i] { |
||||||
|
return false |
||||||
|
} |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
func TestMatcher(t *testing.T) { |
||||||
|
m := New() |
||||||
|
m.Use(logging("logging")) |
||||||
|
m.Add("*", logging("*")) |
||||||
|
m.Add("/foo/*", logging("foo/*")) |
||||||
|
m.Add("/foo/bar/*", logging("foo/bar/*")) |
||||||
|
m.Add("/foo/bar", logging("foo/bar")) |
||||||
|
|
||||||
|
if ms := m.Match("/"); len(ms) != 2 { |
||||||
|
t.Fatal("not equal") |
||||||
|
} else if !equal(ms, "logging", "*") { |
||||||
|
t.Fatal("not equal") |
||||||
|
} |
||||||
|
|
||||||
|
if ms := m.Match("/foo/xxx"); len(ms) != 2 { |
||||||
|
t.Fatal("not equal") |
||||||
|
} else if !equal(ms, "logging", "foo/*") { |
||||||
|
t.Fatal("not equal") |
||||||
|
} |
||||||
|
|
||||||
|
if ms := m.Match("/foo/bar"); len(ms) != 2 { |
||||||
|
t.Fatal("not equal") |
||||||
|
} else if !equal(ms, "logging", "foo/bar") { |
||||||
|
t.Fatal("not equal") |
||||||
|
} |
||||||
|
|
||||||
|
if ms := m.Match("/foo/bar/x"); len(ms) != 2 { |
||||||
|
t.Fatal("not equal") |
||||||
|
} else if !equal(ms, "logging", "foo/bar/*") { |
||||||
|
t.Fatal("not equal") |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue