parent
3660a8d65d
commit
7977deac65
@ -0,0 +1,118 @@ |
||||
package selector |
||||
|
||||
import ( |
||||
"context" |
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
"regexp" |
||||
"strings" |
||||
) |
||||
|
||||
type ( |
||||
transporter func(ctx context.Context) (transport.Transporter, bool) |
||||
match func(operation string) bool |
||||
) |
||||
|
||||
var ( |
||||
serverTransporter transporter = func(ctx context.Context) (transport.Transporter, bool) { |
||||
return transport.FromServerContext(ctx) |
||||
} |
||||
clientTransporter transporter = func(ctx context.Context) (transport.Transporter, bool) { |
||||
return transport.FromClientContext(ctx) |
||||
} |
||||
) |
||||
|
||||
type Builder struct { |
||||
client bool |
||||
|
||||
prefix []string |
||||
regex []string |
||||
path []string |
||||
|
||||
ms []middleware.Middleware |
||||
} |
||||
|
||||
func Server(ms ...middleware.Middleware) *Builder { |
||||
return &Builder{ms: ms} |
||||
} |
||||
|
||||
func Client(ms ...middleware.Middleware) *Builder { |
||||
return &Builder{client: true, ms: ms} |
||||
} |
||||
|
||||
func (b *Builder) Prefix(prefix ...string) *Builder { |
||||
b.prefix = prefix |
||||
return b |
||||
} |
||||
|
||||
func (b *Builder) Regex(regex ...string) *Builder { |
||||
b.regex = regex |
||||
return b |
||||
} |
||||
func (b *Builder) Path(path ...string) *Builder { |
||||
b.path = path |
||||
return b |
||||
} |
||||
|
||||
func (b *Builder) Build() middleware.Middleware { |
||||
var transporter func(ctx context.Context) (transport.Transporter, bool) |
||||
if b.client { |
||||
transporter = clientTransporter |
||||
} else { |
||||
transporter = serverTransporter |
||||
} |
||||
|
||||
return selector(transporter, b.match, b.ms...) |
||||
} |
||||
|
||||
func (b *Builder) match(operation string) bool { |
||||
for _, prefix := range b.prefix { |
||||
if prefixMatch(prefix, operation) { |
||||
return true |
||||
} |
||||
} |
||||
for _, regex := range b.regex { |
||||
if regexMatch(regex, operation) { |
||||
return true |
||||
} |
||||
} |
||||
for _, path := range b.path { |
||||
if pathMatch(path, operation) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func selector(transporter transporter, match match, ms ...middleware.Middleware) middleware.Middleware { |
||||
return func(handler middleware.Handler) middleware.Handler { |
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { |
||||
|
||||
info, ok := transporter(ctx) |
||||
if !ok { |
||||
return handler(ctx, req) |
||||
} |
||||
|
||||
if !match(info.Operation()) { |
||||
return handler(ctx, req) |
||||
} |
||||
return middleware.Chain(ms...)(handler)(ctx, req) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func pathMatch(path string, operation string) bool { |
||||
return path == operation |
||||
} |
||||
|
||||
func prefixMatch(prefix string, operation string) bool { |
||||
return strings.HasPrefix(operation, prefix) |
||||
} |
||||
|
||||
func regexMatch(regex string, operation string) bool { |
||||
r, err := regexp.Compile(regex) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
return r.FindString(operation) == operation |
||||
} |
@ -0,0 +1,118 @@ |
||||
package selector |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"github.com/go-kratos/kratos/v2/middleware" |
||||
"github.com/go-kratos/kratos/v2/transport" |
||||
"testing" |
||||
) |
||||
|
||||
var ( |
||||
_ transport.Transporter = &Transport{} |
||||
) |
||||
|
||||
type Transport struct { |
||||
kind transport.Kind |
||||
endpoint string |
||||
operation string |
||||
} |
||||
|
||||
func (tr *Transport) Kind() transport.Kind { |
||||
return tr.kind |
||||
} |
||||
func (tr *Transport) Endpoint() string { |
||||
return tr.endpoint |
||||
} |
||||
func (tr *Transport) Operation() string { |
||||
return tr.operation |
||||
} |
||||
func (tr *Transport) RequestHeader() transport.Header { |
||||
return nil |
||||
} |
||||
func (tr *Transport) ReplyHeader() transport.Header { |
||||
return nil |
||||
} |
||||
|
||||
func TestMatch(t *testing.T) { |
||||
|
||||
tests := []struct { |
||||
name string |
||||
ctx context.Context |
||||
}{ |
||||
// TODO: Add test cases.
|
||||
{ |
||||
name: "/hello/world", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hello/world"}), |
||||
}, |
||||
{ |
||||
name: "/hi/world", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/hi/world"}), |
||||
}, |
||||
{ |
||||
name: "/test/1234", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/test/1234"}), |
||||
}, |
||||
{ |
||||
name: "/example/kratos", |
||||
ctx: transport.NewServerContext(context.Background(), &Transport{operation: "/example/kratos"}), |
||||
}, |
||||
} |
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
t.Log(req) |
||||
return "reply", nil |
||||
} |
||||
next = Server(testMiddleware).Prefix("/hello/").Regex(`/test/[0-9]+`). |
||||
Path("/example/kratos").Build()(next) |
||||
next(test.ctx, test.name) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestMatchClient(t *testing.T) { |
||||
|
||||
tests := []struct { |
||||
name string |
||||
ctx context.Context |
||||
}{ |
||||
// TODO: Add test cases.
|
||||
{ |
||||
name: "/hello/world", |
||||
ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/hello/world"}), |
||||
}, |
||||
{ |
||||
name: "/hi/world", |
||||
ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/hi/world"}), |
||||
}, |
||||
{ |
||||
name: "/test/1234", |
||||
ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/test/1234"}), |
||||
}, |
||||
{ |
||||
name: "/example/kratos", |
||||
ctx: transport.NewClientContext(context.Background(), &Transport{operation: "/example/kratos"}), |
||||
}, |
||||
} |
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
t.Log(req) |
||||
return "reply", nil |
||||
} |
||||
next = Client(testMiddleware).Prefix("/hello/").Regex(`/test/[0-9]+`). |
||||
Path("/example/kratos").Build()(next) |
||||
next(test.ctx, test.name) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func testMiddleware(handler middleware.Handler) middleware.Handler { |
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) { |
||||
fmt.Println("before") |
||||
reply, err = handler(ctx, req) |
||||
fmt.Println("after") |
||||
return |
||||
} |
||||
} |
Loading…
Reference in new issue