feat(middleware): add selector matcher (#2239)
* feat(middleware): add selector matcher Co-authored-by: chenzhihui <chenzhihui@bilibili.com>pull/2240/head
parent
377356d04d
commit
f3b0da3f04
@ -0,0 +1,62 @@ |
||||
package matcher |
||||
|
||||
import ( |
||||
"sort" |
||||
"strings" |
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
) |
||||
|
||||
// Matcher is a middleware matcher.
|
||||
type Matcher interface { |
||||
Use(ms ...middleware.Middleware) |
||||
Add(selector string, ms ...middleware.Middleware) |
||||
Match(operation string) []middleware.Middleware |
||||
} |
||||
|
||||
// New new a middleware matcher.
|
||||
func New() Matcher { |
||||
return &matcher{ |
||||
matchs: make(map[string][]middleware.Middleware), |
||||
} |
||||
} |
||||
|
||||
type matcher struct { |
||||
prefix []string |
||||
defaults []middleware.Middleware |
||||
matchs map[string][]middleware.Middleware |
||||
} |
||||
|
||||
func (m *matcher) Use(ms ...middleware.Middleware) { |
||||
m.defaults = ms |
||||
} |
||||
|
||||
func (m *matcher) Add(selector string, ms ...middleware.Middleware) { |
||||
if strings.HasSuffix(selector, "*") { |
||||
selector = strings.TrimSuffix(selector, "*") |
||||
m.prefix = append(m.prefix, selector) |
||||
// sort the prefix:
|
||||
// - /foo/bar
|
||||
// - /foo
|
||||
sort.Slice(m.prefix, func(i, j int) bool { |
||||
return m.prefix[i] > m.prefix[j] |
||||
}) |
||||
} |
||||
m.matchs[selector] = ms |
||||
} |
||||
|
||||
func (m *matcher) Match(operation string) []middleware.Middleware { |
||||
ms := make([]middleware.Middleware, 0, len(m.defaults)) |
||||
if len(m.defaults) > 0 { |
||||
ms = append(ms, m.defaults...) |
||||
} |
||||
if next, ok := m.matchs[operation]; ok { |
||||
return append(ms, next...) |
||||
} |
||||
for _, prefix := range m.prefix { |
||||
if strings.HasPrefix(operation, prefix) { |
||||
return append(ms, m.matchs[prefix]...) |
||||
} |
||||
} |
||||
return ms |
||||
} |
@ -0,0 +1,62 @@ |
||||
package matcher |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
) |
||||
|
||||
func logging(module string) middleware.Middleware { |
||||
return func(handler middleware.Handler) middleware.Handler { |
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { |
||||
return module, nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
func equal(ms []middleware.Middleware, modules ...string) bool { |
||||
if len(ms) == 0 { |
||||
return false |
||||
} |
||||
for i, m := range ms { |
||||
x, _ := m(nil)(nil, nil) |
||||
if x != modules[i] { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func TestMatcher(t *testing.T) { |
||||
m := New() |
||||
m.Use(logging("logging")) |
||||
m.Add("*", logging("*")) |
||||
m.Add("/foo/*", logging("foo/*")) |
||||
m.Add("/foo/bar/*", logging("foo/bar/*")) |
||||
m.Add("/foo/bar", logging("foo/bar")) |
||||
|
||||
if ms := m.Match("/"); len(ms) != 2 { |
||||
t.Fatal("not equal") |
||||
} else if !equal(ms, "logging", "*") { |
||||
t.Fatal("not equal") |
||||
} |
||||
|
||||
if ms := m.Match("/foo/xxx"); len(ms) != 2 { |
||||
t.Fatal("not equal") |
||||
} else if !equal(ms, "logging", "foo/*") { |
||||
t.Fatal("not equal") |
||||
} |
||||
|
||||
if ms := m.Match("/foo/bar"); len(ms) != 2 { |
||||
t.Fatal("not equal") |
||||
} else if !equal(ms, "logging", "foo/bar") { |
||||
t.Fatal("not equal") |
||||
} |
||||
|
||||
if ms := m.Match("/foo/bar/x"); len(ms) != 2 { |
||||
t.Fatal("not equal") |
||||
} else if !equal(ms, "logging", "foo/bar/*") { |
||||
t.Fatal("not equal") |
||||
} |
||||
} |
Loading…
Reference in new issue