add custom binding (#2428)

* add custom binding

* fix test

* fix empty value
pull/2431/head
Tony Chen 2 years ago committed by GitHub
parent 468630cc4b
commit a680321309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      encoding/form/proto_decode.go
  2. 18
      transport/http/codec.go
  3. 6
      transport/http/context.go
  4. 20
      transport/http/context_test.go
  5. 24
      transport/http/server.go
  6. 4
      transport/http/server_test.go

@ -99,6 +99,9 @@ func getDescriptorByFieldAndName(fields protoreflect.FieldDescriptors, fieldName
}
func populateField(fd protoreflect.FieldDescriptor, v protoreflect.Message, value string) error {
if value == "" {
return nil
}
val, err := parseField(fd, value)
if err != nil {
return fmt.Errorf("parsing field %q: %w", fd.FullName().Name(), err)

@ -4,10 +4,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/httputil"
"github.com/go-kratos/kratos/v2/transport/http/binding"
"github.com/gorilla/mux"
)
// SupportPackageIsVersion1 These constants should not be referenced from any other code.
@ -37,6 +40,21 @@ type EncodeResponseFunc func(http.ResponseWriter, *http.Request, interface{}) er
// EncodeErrorFunc is encode error func.
type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error)
// DefaultRequestVars decodes the request vars to object.
func DefaultRequestVars(r *http.Request, v interface{}) error {
raws := mux.Vars(r)
vars := make(url.Values, len(raws))
for k, v := range raws {
vars[k] = []string{v}
}
return binding.BindQuery(vars, v)
}
// DefaultRequestQuery decodes the request vars to object.
func DefaultRequestQuery(r *http.Request, v interface{}) error {
return binding.BindQuery(r.URL.Query(), v)
}
// DefaultRequestDecoder decodes the request body to object.
func DefaultRequestDecoder(r *http.Request, v interface{}) error {
codec, ok := CodecForRequest(r, "Content-Type")

@ -96,9 +96,9 @@ func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler {
}
return middleware.Chain(c.router.srv.middleware.Match(c.req.URL.Path)...)(h)
}
func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) }
func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) }
func (c *wrapper) BindQuery(v interface{}) error { return binding.BindQuery(c.Query(), v) }
func (c *wrapper) Bind(v interface{}) error { return c.router.srv.decBody(c.req, v) }
func (c *wrapper) BindVars(v interface{}) error { return c.router.srv.decVars(c.req, v) }
func (c *wrapper) BindQuery(v interface{}) error { return c.router.srv.decQuery(c.req, v) }
func (c *wrapper) BindForm(v interface{}) error { return binding.BindForm(c.req, v) }
func (c *wrapper) Returns(v interface{}, err error) error {
if err != nil {

@ -12,9 +12,11 @@ import (
"time"
)
var testRouter = &Router{srv: NewServer()}
func TestContextHeader(t *testing.T) {
w := wrapper{
router: nil,
router: testRouter,
req: &http.Request{Header: map[string][]string{"name": {"kratos"}}},
res: nil,
w: responseWriter{},
@ -27,7 +29,7 @@ func TestContextHeader(t *testing.T) {
func TestContextForm(t *testing.T) {
w := wrapper{
router: nil,
router: testRouter,
req: &http.Request{Header: map[string][]string{"name": {"kratos"}}, Method: "POST"},
res: nil,
w: responseWriter{},
@ -38,7 +40,7 @@ func TestContextForm(t *testing.T) {
}
w = wrapper{
router: nil,
router: testRouter,
req: &http.Request{Form: map[string][]string{"name": {"kratos"}}},
res: nil,
w: responseWriter{},
@ -51,7 +53,7 @@ func TestContextForm(t *testing.T) {
func TestContextQuery(t *testing.T) {
w := wrapper{
router: nil,
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "github.com", Path: "go-kratos/kratos", RawQuery: "page=1"}, Method: "POST"},
res: nil,
w: responseWriter{},
@ -65,7 +67,7 @@ func TestContextQuery(t *testing.T) {
func TestContextRequest(t *testing.T) {
req := &http.Request{Method: "POST"}
w := wrapper{
router: nil,
router: testRouter,
req: req,
res: nil,
w: responseWriter{},
@ -100,7 +102,7 @@ func TestContextResponse(t *testing.T) {
func TestContextBindQuery(t *testing.T) {
w := wrapper{
router: nil,
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev", RawQuery: "page=2"}},
res: nil,
w: responseWriter{},
@ -120,7 +122,7 @@ func TestContextBindQuery(t *testing.T) {
func TestContextBindForm(t *testing.T) {
w := wrapper{
router: nil,
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev"}, Form: map[string][]string{"page": {"2"}}},
res: nil,
w: responseWriter{},
@ -141,7 +143,7 @@ func TestContextBindForm(t *testing.T) {
func TestContextResponseReturn(t *testing.T) {
writer := httptest.NewRecorder()
w := wrapper{
router: nil,
router: testRouter,
req: nil,
res: writer,
w: responseWriter{},
@ -174,7 +176,7 @@ func TestContextCtx(t *testing.T) {
req := &http.Request{Method: "POST"}
req = req.WithContext(ctx)
w := wrapper{
router: &Router{srv: &Server{enc: DefaultResponseEncoder}},
router: testRouter,
req: req,
res: nil,
w: responseWriter{},

@ -70,10 +70,24 @@ func Filter(filters ...FilterFunc) ServerOption {
}
}
// RequestVarsDecoder with request decoder.
func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.decVars = dec
}
}
// RequestQueryDecoder with request decoder.
func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.decQuery = dec
}
}
// RequestDecoder with request decoder.
func RequestDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.dec = dec
o.decBody = dec
}
}
@ -126,7 +140,9 @@ type Server struct {
timeout time.Duration
filters []FilterFunc
middleware matcher.Matcher
dec DecodeRequestFunc
decVars DecodeRequestFunc
decQuery DecodeRequestFunc
decBody DecodeRequestFunc
enc EncodeResponseFunc
ene EncodeErrorFunc
strictSlash bool
@ -140,7 +156,9 @@ func NewServer(opts ...ServerOption) *Server {
address: ":0",
timeout: 1 * time.Second,
middleware: matcher.New(),
dec: DefaultRequestDecoder,
decVars: DefaultRequestVars,
decQuery: DefaultRequestQuery,
decBody: DefaultRequestDecoder,
enc: DefaultResponseEncoder,
ene: DefaultErrorEncoder,
strictSlash: true,

@ -316,8 +316,8 @@ func TestRequestDecoder(t *testing.T) {
o := &Server{}
v := func(*http.Request, interface{}) error { return nil }
RequestDecoder(v)(o)
if o.dec == nil {
t.Errorf("expected nil got %v", o.dec)
if o.decBody == nil {
t.Errorf("expected nil got %v", o.decBody)
}
}

Loading…
Cancel
Save