diff --git a/pkg/net/http/blademaster/binding/binding.go b/pkg/net/http/blademaster/binding/binding.go new file mode 100644 index 000000000..46f1c770d --- /dev/null +++ b/pkg/net/http/blademaster/binding/binding.go @@ -0,0 +1,85 @@ +package binding + +import ( + "net/http" + "strings" + + "gopkg.in/go-playground/validator.v9" +) + +// MIME +const ( + MIMEJSON = "application/json" + MIMEHTML = "text/html" + MIMEXML = "application/xml" + MIMEXML2 = "text/xml" + MIMEPlain = "text/plain" + MIMEPOSTForm = "application/x-www-form-urlencoded" + MIMEMultipartPOSTForm = "multipart/form-data" +) + +// Binding http binding request interface. +type Binding interface { + Name() string + Bind(*http.Request, interface{}) error +} + +// StructValidator http validator interface. +type StructValidator interface { + // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. + // If the received type is not a struct, any validation should be skipped and nil must be returned. + // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the struct is not valid or the validation itself fails, a descriptive error should be returned. + // Otherwise nil must be returned. + ValidateStruct(interface{}) error + + // RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key + // NOTE: if the key already exists, the previous validation function will be replaced. + // NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation + RegisterValidation(string, validator.Func) error +} + +// Validator default validator. +var Validator StructValidator = &defaultValidator{} + +// Binding +var ( + JSON = jsonBinding{} + XML = xmlBinding{} + Form = formBinding{} + Query = queryBinding{} + FormPost = formPostBinding{} + FormMultipart = formMultipartBinding{} +) + +// Default get by binding type by method and contexttype. +func Default(method, contentType string) Binding { + if method == "GET" { + return Form + } + + contentType = stripContentTypeParam(contentType) + switch contentType { + case MIMEJSON: + return JSON + case MIMEXML, MIMEXML2: + return XML + default: //case MIMEPOSTForm, MIMEMultipartPOSTForm: + return Form + } +} + +func validate(obj interface{}) error { + if Validator == nil { + return nil + } + return Validator.ValidateStruct(obj) +} + +func stripContentTypeParam(contentType string) string { + i := strings.Index(contentType, ";") + if i != -1 { + contentType = contentType[:i] + } + return contentType +} diff --git a/pkg/net/http/blademaster/binding/binding_test.go b/pkg/net/http/blademaster/binding/binding_test.go new file mode 100644 index 000000000..8a5da8a6e --- /dev/null +++ b/pkg/net/http/blademaster/binding/binding_test.go @@ -0,0 +1,342 @@ +package binding + +import ( + "bytes" + "mime/multipart" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +type FooStruct struct { + Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"` +} + +type FooBarStruct struct { + FooStruct + Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"` + Slice []string `form:"slice" validate:"max=10"` +} + +type ComplexDefaultStruct struct { + Int int `form:"int" default:"999"` + String string `form:"string" default:"default-string"` + Bool bool `form:"bool" default:"false"` + Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"` + Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"` +} + +type Int8SliceStruct struct { + State []int8 `form:"state,split"` +} + +type Int64SliceStruct struct { + State []int64 `form:"state,split"` +} + +type StringSliceStruct struct { + State []string `form:"state,split"` +} + +func TestBindingDefault(t *testing.T) { + assert.Equal(t, Default("GET", ""), Form) + assert.Equal(t, Default("GET", MIMEJSON), Form) + assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form) + + assert.Equal(t, Default("POST", MIMEJSON), JSON) + assert.Equal(t, Default("PUT", MIMEJSON), JSON) + + assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON) + assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON) + + assert.Equal(t, Default("POST", MIMEXML), XML) + assert.Equal(t, Default("PUT", MIMEXML2), XML) + + assert.Equal(t, Default("POST", MIMEPOSTForm), Form) + assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) + + assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form) + assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form) + + assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) + assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) + +} + +func TestStripContentType(t *testing.T) { + c1 := "application/vnd.mozilla.xul+xml" + c2 := "application/vnd.mozilla.xul+xml; charset=utf-8" + assert.Equal(t, stripContentTypeParam(c1), c1) + assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml") +} + +func TestBindInt8Form(t *testing.T) { + params := "state=1,2,3" + req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q := new(Int8SliceStruct) + Form.Bind(req, q) + assert.EqualValues(t, []int8{1, 2, 3}, q.State) + + params = "state=1,2,3,256" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(Int8SliceStruct) + assert.Error(t, Form.Bind(req, q)) + + params = "state=" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(Int8SliceStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Len(t, q.State, 0) + + params = "state=1,,2" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(Int8SliceStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.EqualValues(t, []int8{1, 2}, q.State) +} + +func TestBindInt64Form(t *testing.T) { + params := "state=1,2,3" + req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q := new(Int64SliceStruct) + Form.Bind(req, q) + assert.EqualValues(t, []int64{1, 2, 3}, q.State) + + params = "state=" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(Int64SliceStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Len(t, q.State, 0) +} + +func TestBindStringForm(t *testing.T) { + params := "state=1,2,3" + req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q := new(StringSliceStruct) + Form.Bind(req, q) + assert.EqualValues(t, []string{"1", "2", "3"}, q.State) + + params = "state=" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(StringSliceStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Len(t, q.State, 0) + + params = "state=p,,p" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(StringSliceStruct) + Form.Bind(req, q) + assert.EqualValues(t, []string{"p", "p"}, q.State) +} + +func TestBindingJSON(t *testing.T) { + testBodyBinding(t, + JSON, "json", + "/", "/", + `{"foo": "bar"}`, `{"bar": "foo"}`) +} + +func TestBindingForm(t *testing.T) { + testFormBinding(t, "POST", + "/", "/", + "foo=bar&bar=foo&slice=a&slice=b", "bar2=foo") +} + +func TestBindingForm2(t *testing.T) { + testFormBinding(t, "GET", + "/?foo=bar&bar=foo", "/?bar2=foo", + "", "") +} + +func TestBindingQuery(t *testing.T) { + testQueryBinding(t, "POST", + "/?foo=bar&bar=foo", "/", + "foo=unused", "bar2=foo") +} + +func TestBindingQuery2(t *testing.T) { + testQueryBinding(t, "GET", + "/?foo=bar&bar=foo", "/?bar2=foo", + "foo=unused", "") +} + +func TestBindingXML(t *testing.T) { + testBodyBinding(t, + XML, "xml", + "/", "/", + "bar", "foo") +} + +func createFormPostRequest() *http.Request { + req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) + req.Header.Set("Content-Type", MIMEPOSTForm) + return req +} + +func createFormMultipartRequest() *http.Request { + boundary := "--testboundary" + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + defer mw.Close() + + mw.SetBoundary(boundary) + mw.WriteField("foo", "bar") + mw.WriteField("bar", "foo") + req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) + req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) + return req +} + +func TestBindingFormPost(t *testing.T) { + req := createFormPostRequest() + var obj FooBarStruct + FormPost.Bind(req, &obj) + + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") +} + +func TestBindingFormMultipart(t *testing.T) { + req := createFormMultipartRequest() + var obj FooBarStruct + FormMultipart.Bind(req, &obj) + + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") +} + +func TestValidationFails(t *testing.T) { + var obj FooStruct + req := requestWithBody("POST", "/", `{"bar": "foo"}`) + err := JSON.Bind(req, &obj) + assert.Error(t, err) +} + +func TestValidationDisabled(t *testing.T) { + backup := Validator + Validator = nil + defer func() { Validator = backup }() + + var obj FooStruct + req := requestWithBody("POST", "/", `{"bar": "foo"}`) + err := JSON.Bind(req, &obj) + assert.NoError(t, err) +} + +func TestExistsSucceeds(t *testing.T) { + type HogeStruct struct { + Hoge *int `json:"hoge" binding:"exists"` + } + + var obj HogeStruct + req := requestWithBody("POST", "/", `{"hoge": 0}`) + err := JSON.Bind(req, &obj) + assert.NoError(t, err) +} + +func TestFormDefaultValue(t *testing.T) { + params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8" + req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q := new(ComplexDefaultStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Equal(t, 333, q.Int) + assert.Equal(t, "hello", q.String) + assert.Equal(t, true, q.Bool) + assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice) + assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice) + + params = "string=hello&bool=false" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(ComplexDefaultStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Equal(t, 999, q.Int) + assert.Equal(t, "hello", q.String) + assert.Equal(t, false, q.Bool) + assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) + assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) + + params = "strings=hello" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(ComplexDefaultStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Equal(t, 999, q.Int) + assert.Equal(t, "default-string", q.String) + assert.Equal(t, false, q.Bool) + assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) + assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) + + params = "int=&string=&bool=true&int64_slice=&int8_slice=" + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + q = new(ComplexDefaultStruct) + assert.NoError(t, Form.Bind(req, q)) + assert.Equal(t, 999, q.Int) + assert.Equal(t, "default-string", q.String) + assert.Equal(t, true, q.Bool) + assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) + assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) +} + +func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) { + b := Form + assert.Equal(t, b.Name(), "form") + + obj := FooBarStruct{} + req := requestWithBody(method, path, body) + if method == "POST" { + req.Header.Add("Content-Type", MIMEPOSTForm) + } + err := b.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") + + obj = FooBarStruct{} + req = requestWithBody(method, badPath, badBody) + err = JSON.Bind(req, &obj) + assert.Error(t, err) +} + +func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) { + b := Query + assert.Equal(t, b.Name(), "query") + + obj := FooBarStruct{} + req := requestWithBody(method, path, body) + if method == "POST" { + req.Header.Add("Content-Type", MIMEPOSTForm) + } + err := b.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, obj.Foo, "bar") + assert.Equal(t, obj.Bar, "foo") +} + +func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { + assert.Equal(t, b.Name(), name) + + obj := FooStruct{} + req := requestWithBody("POST", path, body) + err := b.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, obj.Foo, "bar") + + obj = FooStruct{} + req = requestWithBody("POST", badPath, badBody) + err = JSON.Bind(req, &obj) + assert.Error(t, err) +} + +func requestWithBody(method, path, body string) (req *http.Request) { + req, _ = http.NewRequest(method, path, bytes.NewBufferString(body)) + return +} +func BenchmarkBindingForm(b *testing.B) { + req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w") + req.Header.Add("Content-Type", MIMEPOSTForm) + f := Form + for i := 0; i < b.N; i++ { + obj := FooBarStruct{} + f.Bind(req, &obj) + } +} diff --git a/pkg/net/http/blademaster/binding/default_validator.go b/pkg/net/http/blademaster/binding/default_validator.go new file mode 100644 index 000000000..5dbf67ed2 --- /dev/null +++ b/pkg/net/http/blademaster/binding/default_validator.go @@ -0,0 +1,45 @@ +package binding + +import ( + "reflect" + "sync" + + "gopkg.in/go-playground/validator.v9" +) + +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +var _ StructValidator = &defaultValidator{} + +func (v *defaultValidator) ValidateStruct(obj interface{}) error { + if kindOfData(obj) == reflect.Struct { + v.lazyinit() + if err := v.validate.Struct(obj); err != nil { + return err + } + } + return nil +} + +func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error { + v.lazyinit() + return v.validate.RegisterValidation(key, fn) +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + }) +} + +func kindOfData(data interface{}) reflect.Kind { + value := reflect.ValueOf(data) + valueType := value.Kind() + if valueType == reflect.Ptr { + valueType = value.Elem().Kind() + } + return valueType +} diff --git a/pkg/net/http/blademaster/binding/example/test.pb.go b/pkg/net/http/blademaster/binding/example/test.pb.go new file mode 100644 index 000000000..3de8444ff --- /dev/null +++ b/pkg/net/http/blademaster/binding/example/test.pb.go @@ -0,0 +1,113 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package example is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + Test +*/ +package example + +import proto "github.com/golang/protobuf/proto" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +type FOO int32 + +const ( + FOO_X FOO = 17 +) + +var FOO_name = map[int32]string{ + 17: "X", +} +var FOO_value = map[string]int32{ + "X": 17, +} + +func (x FOO) Enum() *FOO { + p := new(FOO) + *p = x + return p +} +func (x FOO) String() string { + return proto.EnumName(FOO_name, int32(x)) +} +func (x *FOO) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") + if err != nil { + return err + } + *x = FOO(value) + return nil +} + +type Test struct { + Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` + Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` + Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` + Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Test) Reset() { *m = Test{} } +func (m *Test) String() string { return proto.CompactTextString(m) } +func (*Test) ProtoMessage() {} + +const Default_Test_Type int32 = 77 + +func (m *Test) GetLabel() string { + if m != nil && m.Label != nil { + return *m.Label + } + return "" +} + +func (m *Test) GetType() int32 { + if m != nil && m.Type != nil { + return *m.Type + } + return Default_Test_Type +} + +func (m *Test) GetReps() []int64 { + if m != nil { + return m.Reps + } + return nil +} + +func (m *Test) GetOptionalgroup() *Test_OptionalGroup { + if m != nil { + return m.Optionalgroup + } + return nil +} + +type Test_OptionalGroup struct { + RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } +func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } +func (*Test_OptionalGroup) ProtoMessage() {} + +func (m *Test_OptionalGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" +} + +func init() { + proto.RegisterEnum("example.FOO", FOO_name, FOO_value) +} diff --git a/pkg/net/http/blademaster/binding/example/test.proto b/pkg/net/http/blademaster/binding/example/test.proto new file mode 100644 index 000000000..8ee9800aa --- /dev/null +++ b/pkg/net/http/blademaster/binding/example/test.proto @@ -0,0 +1,12 @@ +package example; + +enum FOO {X=17;}; + +message Test { + required string label = 1; + optional int32 type = 2[default=77]; + repeated int64 reps = 3; + optional group OptionalGroup = 4{ + required string RequiredField = 5; + } +} diff --git a/pkg/net/http/blademaster/binding/example_test.go b/pkg/net/http/blademaster/binding/example_test.go new file mode 100644 index 000000000..c667a517c --- /dev/null +++ b/pkg/net/http/blademaster/binding/example_test.go @@ -0,0 +1,36 @@ +package binding + +import ( + "fmt" + "log" + "net/http" +) + +type Arg struct { + Max int64 `form:"max" validate:"max=10"` + Min int64 `form:"min" validate:"min=2"` + Range int64 `form:"range" validate:"min=1,max=10"` + // use split option to split arg 1,2,3 into slice [1 2 3] + // otherwise slice type with parse url.Values (eg:a=b&a=c) default. + Slice []int64 `form:"slice,split" validate:"min=1"` +} + +func ExampleBinding() { + req := initHTTP("max=9&min=3&range=3&slice=1,2,3") + arg := new(Arg) + if err := Form.Bind(req, arg); err != nil { + log.Fatal(err) + } + fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice) + // Output: + // arg.Max 9 + // arg.Min 3 + // arg.Range 3 + // arg.Slice [1 2 3] +} + +func initHTTP(params string) (req *http.Request) { + req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) + req.ParseForm() + return +} diff --git a/pkg/net/http/blademaster/binding/form.go b/pkg/net/http/blademaster/binding/form.go new file mode 100644 index 000000000..61aa5ee83 --- /dev/null +++ b/pkg/net/http/blademaster/binding/form.go @@ -0,0 +1,55 @@ +package binding + +import ( + "net/http" + + "github.com/pkg/errors" +) + +const defaultMemory = 32 * 1024 * 1024 + +type formBinding struct{} +type formPostBinding struct{} +type formMultipartBinding struct{} + +func (f formBinding) Name() string { + return "form" +} + +func (f formBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseForm(); err != nil { + return errors.WithStack(err) + } + if err := mapForm(obj, req.Form); err != nil { + return err + } + return validate(obj) +} + +func (f formPostBinding) Name() string { + return "form-urlencoded" +} + +func (f formPostBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseForm(); err != nil { + return errors.WithStack(err) + } + if err := mapForm(obj, req.PostForm); err != nil { + return err + } + return validate(obj) +} + +func (f formMultipartBinding) Name() string { + return "multipart/form-data" +} + +func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error { + if err := req.ParseMultipartForm(defaultMemory); err != nil { + return errors.WithStack(err) + } + if err := mapForm(obj, req.MultipartForm.Value); err != nil { + return err + } + return validate(obj) +} diff --git a/pkg/net/http/blademaster/binding/form_mapping.go b/pkg/net/http/blademaster/binding/form_mapping.go new file mode 100644 index 000000000..ac4ecd116 --- /dev/null +++ b/pkg/net/http/blademaster/binding/form_mapping.go @@ -0,0 +1,276 @@ +package binding + +import ( + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/pkg/errors" +) + +// scache struct reflect type cache. +var scache = &cache{ + data: make(map[reflect.Type]*sinfo), +} + +type cache struct { + data map[reflect.Type]*sinfo + mutex sync.RWMutex +} + +func (c *cache) get(obj reflect.Type) (s *sinfo) { + var ok bool + c.mutex.RLock() + if s, ok = c.data[obj]; !ok { + c.mutex.RUnlock() + s = c.set(obj) + return + } + c.mutex.RUnlock() + return +} + +func (c *cache) set(obj reflect.Type) (s *sinfo) { + s = new(sinfo) + tp := obj.Elem() + for i := 0; i < tp.NumField(); i++ { + fd := new(field) + fd.tp = tp.Field(i) + tag := fd.tp.Tag.Get("form") + fd.name, fd.option = parseTag(tag) + if defV := fd.tp.Tag.Get("default"); defV != "" { + dv := reflect.New(fd.tp.Type).Elem() + setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option) + fd.hasDefault = true + fd.defaultValue = dv + } + s.field = append(s.field, fd) + } + c.mutex.Lock() + c.data[obj] = s + c.mutex.Unlock() + return +} + +type sinfo struct { + field []*field +} + +type field struct { + tp reflect.StructField + name string + option tagOptions + + hasDefault bool // if field had default value + defaultValue reflect.Value // field default value +} + +func mapForm(ptr interface{}, form map[string][]string) error { + sinfo := scache.get(reflect.TypeOf(ptr)) + val := reflect.ValueOf(ptr).Elem() + for i, fd := range sinfo.field { + typeField := fd.tp + structField := val.Field(i) + if !structField.CanSet() { + continue + } + + structFieldKind := structField.Kind() + inputFieldName := fd.name + if inputFieldName == "" { + inputFieldName = typeField.Name + + // if "form" tag is nil, we inspect if the field is a struct. + // this would not make sense for JSON parsing but it does for a form + // since data is flatten + if structFieldKind == reflect.Struct { + err := mapForm(structField.Addr().Interface(), form) + if err != nil { + return err + } + continue + } + } + inputValue, exists := form[inputFieldName] + if !exists { + // Set the field as default value when the input value is not exist + if fd.hasDefault { + structField.Set(fd.defaultValue) + } + continue + } + // Set the field as default value when the input value is empty + if fd.hasDefault && inputValue[0] == "" { + structField.Set(fd.defaultValue) + continue + } + if _, isTime := structField.Interface().(time.Time); isTime { + if err := setTimeField(inputValue[0], typeField, structField); err != nil { + return err + } + continue + } + if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil { + return err + } + } + return nil +} + +func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error { + switch valueKind { + case reflect.Int: + return setIntField(val[0], 0, structField) + case reflect.Int8: + return setIntField(val[0], 8, structField) + case reflect.Int16: + return setIntField(val[0], 16, structField) + case reflect.Int32: + return setIntField(val[0], 32, structField) + case reflect.Int64: + return setIntField(val[0], 64, structField) + case reflect.Uint: + return setUintField(val[0], 0, structField) + case reflect.Uint8: + return setUintField(val[0], 8, structField) + case reflect.Uint16: + return setUintField(val[0], 16, structField) + case reflect.Uint32: + return setUintField(val[0], 32, structField) + case reflect.Uint64: + return setUintField(val[0], 64, structField) + case reflect.Bool: + return setBoolField(val[0], structField) + case reflect.Float32: + return setFloatField(val[0], 32, structField) + case reflect.Float64: + return setFloatField(val[0], 64, structField) + case reflect.String: + structField.SetString(val[0]) + case reflect.Slice: + if option.Contains("split") { + val = strings.Split(val[0], ",") + } + filtered := filterEmpty(val) + switch structField.Type().Elem().Kind() { + case reflect.Int64: + valSli := make([]int64, 0, len(filtered)) + for i := 0; i < len(filtered); i++ { + d, err := strconv.ParseInt(filtered[i], 10, 64) + if err != nil { + return err + } + valSli = append(valSli, d) + } + structField.Set(reflect.ValueOf(valSli)) + case reflect.String: + valSli := make([]string, 0, len(filtered)) + for i := 0; i < len(filtered); i++ { + valSli = append(valSli, filtered[i]) + } + structField.Set(reflect.ValueOf(valSli)) + default: + sliceOf := structField.Type().Elem().Kind() + numElems := len(filtered) + slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered)) + for i := 0; i < numElems; i++ { + if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil { + return err + } + } + structField.Set(slice) + } + default: + return errors.New("Unknown type") + } + return nil +} + +func setIntField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return errors.WithStack(err) +} + +func setUintField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return errors.WithStack(err) +} + +func setBoolField(val string, field reflect.Value) error { + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err == nil { + field.SetBool(boolVal) + } + return nil +} + +func setFloatField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return errors.WithStack(err) +} + +func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { + timeFormat := structField.Tag.Get("time_format") + if timeFormat == "" { + return errors.New("Blank time format") + } + + if val == "" { + value.Set(reflect.ValueOf(time.Time{})) + return nil + } + + l := time.Local + if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { + l = time.UTC + } + + if locTag := structField.Tag.Get("time_location"); locTag != "" { + loc, err := time.LoadLocation(locTag) + if err != nil { + return errors.WithStack(err) + } + l = loc + } + + t, err := time.ParseInLocation(timeFormat, val, l) + if err != nil { + return errors.WithStack(err) + } + + value.Set(reflect.ValueOf(t)) + return nil +} + +func filterEmpty(val []string) []string { + filtered := make([]string, 0, len(val)) + for _, v := range val { + if v != "" { + filtered = append(filtered, v) + } + } + return filtered +} diff --git a/pkg/net/http/blademaster/binding/json.go b/pkg/net/http/blademaster/binding/json.go new file mode 100644 index 000000000..f01e479b3 --- /dev/null +++ b/pkg/net/http/blademaster/binding/json.go @@ -0,0 +1,22 @@ +package binding + +import ( + "encoding/json" + "net/http" + + "github.com/pkg/errors" +) + +type jsonBinding struct{} + +func (jsonBinding) Name() string { + return "json" +} + +func (jsonBinding) Bind(req *http.Request, obj interface{}) error { + decoder := json.NewDecoder(req.Body) + if err := decoder.Decode(obj); err != nil { + return errors.WithStack(err) + } + return validate(obj) +} diff --git a/pkg/net/http/blademaster/binding/query.go b/pkg/net/http/blademaster/binding/query.go new file mode 100644 index 000000000..b169436eb --- /dev/null +++ b/pkg/net/http/blademaster/binding/query.go @@ -0,0 +1,19 @@ +package binding + +import ( + "net/http" +) + +type queryBinding struct{} + +func (queryBinding) Name() string { + return "query" +} + +func (queryBinding) Bind(req *http.Request, obj interface{}) error { + values := req.URL.Query() + if err := mapForm(obj, values); err != nil { + return err + } + return validate(obj) +} diff --git a/pkg/net/http/blademaster/binding/tags.go b/pkg/net/http/blademaster/binding/tags.go new file mode 100644 index 000000000..535bd8624 --- /dev/null +++ b/pkg/net/http/blademaster/binding/tags.go @@ -0,0 +1,44 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binding + +import ( + "strings" +) + +// tagOptions is the string following a comma in a struct field's "json" +// tag, or the empty string. It does not include the leading comma. +type tagOptions string + +// parseTag splits a struct field's json tag into its name and +// comma-separated options. +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + return tag, tagOptions("") +} + +// Contains reports whether a comma-separated list of options +// contains a particular substr flag. substr must be surrounded by a +// string boundary or commas. +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + s := string(o) + for s != "" { + var next string + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + if s == optionName { + return true + } + s = next + } + return false +} diff --git a/pkg/net/http/blademaster/binding/validate_test.go b/pkg/net/http/blademaster/binding/validate_test.go new file mode 100644 index 000000000..ac1793713 --- /dev/null +++ b/pkg/net/http/blademaster/binding/validate_test.go @@ -0,0 +1,209 @@ +package binding + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type testInterface interface { + String() string +} + +type substructNoValidation struct { + IString string + IInt int +} + +type mapNoValidationSub map[string]substructNoValidation + +type structNoValidationValues struct { + substructNoValidation + + Boolean bool + + Uinteger uint + Integer int + Integer8 int8 + Integer16 int16 + Integer32 int32 + Integer64 int64 + Uinteger8 uint8 + Uinteger16 uint16 + Uinteger32 uint32 + Uinteger64 uint64 + + Float32 float32 + Float64 float64 + + String string + + Date time.Time + + Struct substructNoValidation + InlinedStruct struct { + String []string + Integer int + } + + IntSlice []int + IntPointerSlice []*int + StructPointerSlice []*substructNoValidation + StructSlice []substructNoValidation + InterfaceSlice []testInterface + + UniversalInterface interface{} + CustomInterface testInterface + + FloatMap map[string]float32 + StructMap mapNoValidationSub +} + +func createNoValidationValues() structNoValidationValues { + integer := 1 + s := structNoValidationValues{ + Boolean: true, + Uinteger: 1 << 29, + Integer: -10000, + Integer8: 120, + Integer16: -20000, + Integer32: 1 << 29, + Integer64: 1 << 61, + Uinteger8: 250, + Uinteger16: 50000, + Uinteger32: 1 << 31, + Uinteger64: 1 << 62, + Float32: 123.456, + Float64: 123.456789, + String: "text", + Date: time.Time{}, + CustomInterface: &bytes.Buffer{}, + Struct: substructNoValidation{}, + IntSlice: []int{-3, -2, 1, 0, 1, 2, 3}, + IntPointerSlice: []*int{&integer}, + StructSlice: []substructNoValidation{}, + UniversalInterface: 1.2, + FloatMap: map[string]float32{ + "foo": 1.23, + "bar": 232.323, + }, + StructMap: mapNoValidationSub{ + "foo": substructNoValidation{}, + "bar": substructNoValidation{}, + }, + // StructPointerSlice []noValidationSub + // InterfaceSlice []testInterface + } + s.InlinedStruct.Integer = 1000 + s.InlinedStruct.String = []string{"first", "second"} + s.IString = "substring" + s.IInt = 987654 + return s +} + +func TestValidateNoValidationValues(t *testing.T) { + origin := createNoValidationValues() + test := createNoValidationValues() + empty := structNoValidationValues{} + + assert.Nil(t, validate(test)) + assert.Nil(t, validate(&test)) + assert.Nil(t, validate(empty)) + assert.Nil(t, validate(&empty)) + + assert.Equal(t, origin, test) +} + +type structNoValidationPointer struct { + // substructNoValidation + + Boolean bool + + Uinteger *uint + Integer *int + Integer8 *int8 + Integer16 *int16 + Integer32 *int32 + Integer64 *int64 + Uinteger8 *uint8 + Uinteger16 *uint16 + Uinteger32 *uint32 + Uinteger64 *uint64 + + Float32 *float32 + Float64 *float64 + + String *string + + Date *time.Time + + Struct *substructNoValidation + + IntSlice *[]int + IntPointerSlice *[]*int + StructPointerSlice *[]*substructNoValidation + StructSlice *[]substructNoValidation + InterfaceSlice *[]testInterface + + FloatMap *map[string]float32 + StructMap *mapNoValidationSub +} + +func TestValidateNoValidationPointers(t *testing.T) { + //origin := createNoValidation_values() + //test := createNoValidation_values() + empty := structNoValidationPointer{} + + //assert.Nil(t, validate(test)) + //assert.Nil(t, validate(&test)) + assert.Nil(t, validate(empty)) + assert.Nil(t, validate(&empty)) + + //assert.Equal(t, origin, test) +} + +type Object map[string]interface{} + +func TestValidatePrimitives(t *testing.T) { + obj := Object{"foo": "bar", "bar": 1} + assert.NoError(t, validate(obj)) + assert.NoError(t, validate(&obj)) + assert.Equal(t, obj, Object{"foo": "bar", "bar": 1}) + + obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} + assert.NoError(t, validate(obj2)) + assert.NoError(t, validate(&obj2)) + + nu := 10 + assert.NoError(t, validate(nu)) + assert.NoError(t, validate(&nu)) + assert.Equal(t, nu, 10) + + str := "value" + assert.NoError(t, validate(str)) + assert.NoError(t, validate(&str)) + assert.Equal(t, str, "value") +} + +// structCustomValidation is a helper struct we use to check that +// custom validation can be registered on it. +// The `notone` binding directive is for custom validation and registered later. +// type structCustomValidation struct { +// Integer int `binding:"notone"` +// } + +// notOne is a custom validator meant to be used with `validator.v8` library. +// The method signature for `v9` is significantly different and this function +// would need to be changed for tests to pass after upgrade. +// See https://github.com/gin-gonic/gin/pull/1015. +// func notOne( +// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value, +// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string, +// ) bool { +// if val, ok := field.Interface().(int); ok { +// return val != 1 +// } +// return false +// } diff --git a/pkg/net/http/blademaster/binding/xml.go b/pkg/net/http/blademaster/binding/xml.go new file mode 100644 index 000000000..99b303c2f --- /dev/null +++ b/pkg/net/http/blademaster/binding/xml.go @@ -0,0 +1,22 @@ +package binding + +import ( + "encoding/xml" + "net/http" + + "github.com/pkg/errors" +) + +type xmlBinding struct{} + +func (xmlBinding) Name() string { + return "xml" +} + +func (xmlBinding) Bind(req *http.Request, obj interface{}) error { + decoder := xml.NewDecoder(req.Body) + if err := decoder.Decode(obj); err != nil { + return errors.WithStack(err) + } + return validate(obj) +} diff --git a/pkg/net/http/blademaster/context.go b/pkg/net/http/blademaster/context.go new file mode 100644 index 000000000..59154435c --- /dev/null +++ b/pkg/net/http/blademaster/context.go @@ -0,0 +1,306 @@ +package blademaster + +import ( + "context" + "math" + "net/http" + "strconv" + + "github.com/bilibili/Kratos/pkg/ecode" + "github.com/bilibili/Kratos/pkg/net/http/blademaster/binding" + "github.com/bilibili/Kratos/pkg/net/http/blademaster/render" + + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/types" + "github.com/pkg/errors" +) + +const ( + _abortIndex int8 = math.MaxInt8 / 2 +) + +var ( + _openParen = []byte("(") + _closeParen = []byte(")") +) + +// Context is the most important part. It allows us to pass variables between +// middleware, manage the flow, validate the JSON of a request and render a +// JSON response for example. +type Context struct { + context.Context + + Request *http.Request + Writer http.ResponseWriter + + // flow control + index int8 + handlers []HandlerFunc + + // Keys is a key/value pair exclusively for the context of each request. + Keys map[string]interface{} + + Error error + + method string + engine *Engine +} + +/************************************/ +/*********** FLOW CONTROL ***********/ +/************************************/ + +// Next should be used only inside middleware. +// It executes the pending handlers in the chain inside the calling handler. +// See example in godoc. +func (c *Context) Next() { + c.index++ + s := int8(len(c.handlers)) + for ; c.index < s; c.index++ { + // only check method on last handler, otherwise middlewares + // will never be effected if request method is not matched + if c.index == s-1 && c.method != c.Request.Method { + code := http.StatusMethodNotAllowed + c.Error = ecode.MethodNotAllowed + http.Error(c.Writer, http.StatusText(code), code) + return + } + + c.handlers[c.index](c) + } +} + +// Abort prevents pending handlers from being called. Note that this will not stop the current handler. +// Let's say you have an authorization middleware that validates that the current request is authorized. +// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers +// for this request are not called. +func (c *Context) Abort() { + c.index = _abortIndex +} + +// AbortWithStatus calls `Abort()` and writes the headers with the specified status code. +// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401). +func (c *Context) AbortWithStatus(code int) { + c.Status(code) + c.Abort() +} + +// IsAborted returns true if the current context was aborted. +func (c *Context) IsAborted() bool { + return c.index >= _abortIndex +} + +/************************************/ +/******** METADATA MANAGEMENT********/ +/************************************/ + +// Set is used to store a new key/value pair exclusively for this context. +// It also lazy initializes c.Keys if it was not used previously. +func (c *Context) Set(key string, value interface{}) { + if c.Keys == nil { + c.Keys = make(map[string]interface{}) + } + c.Keys[key] = value +} + +// Get returns the value for the given key, ie: (value, true). +// If the value does not exists it returns (nil, false) +func (c *Context) Get(key string) (value interface{}, exists bool) { + value, exists = c.Keys[key] + return +} + +/************************************/ +/******** RESPONSE RENDERING ********/ +/************************************/ + +// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + +// Status sets the HTTP response code. +func (c *Context) Status(code int) { + c.Writer.WriteHeader(code) +} + +// Render http response with http code by a render instance. +func (c *Context) Render(code int, r render.Render) { + r.WriteContentType(c.Writer) + if code > 0 { + c.Status(code) + } + + if !bodyAllowedForStatus(code) { + return + } + + params := c.Request.Form + + cb := params.Get("callback") + jsonp := cb != "" && params.Get("jsonp") == "jsonp" + if jsonp { + c.Writer.Write([]byte(cb)) + c.Writer.Write(_openParen) + } + + if err := r.Render(c.Writer); err != nil { + c.Error = err + return + } + + if jsonp { + if _, err := c.Writer.Write(_closeParen); err != nil { + c.Error = errors.WithStack(err) + } + } +} + +// JSON serializes the given struct as JSON into the response body. +// It also sets the Content-Type as "application/json". +func (c *Context) JSON(data interface{}, err error) { + code := http.StatusOK + c.Error = err + bcode := ecode.Cause(err) + // TODO app allow 5xx? + /* + if bcode.Code() == -500 { + code = http.StatusServiceUnavailable + } + */ + writeStatusCode(c.Writer, bcode.Code()) + c.Render(code, render.JSON{ + Code: bcode.Code(), + Message: bcode.Message(), + Data: data, + }) +} + +// JSONMap serializes the given map as map JSON into the response body. +// It also sets the Content-Type as "application/json". +func (c *Context) JSONMap(data map[string]interface{}, err error) { + code := http.StatusOK + c.Error = err + bcode := ecode.Cause(err) + // TODO app allow 5xx? + /* + if bcode.Code() == -500 { + code = http.StatusServiceUnavailable + } + */ + writeStatusCode(c.Writer, bcode.Code()) + data["code"] = bcode.Code() + if _, ok := data["message"]; !ok { + data["message"] = bcode.Message() + } + c.Render(code, render.MapJSON(data)) +} + +// XML serializes the given struct as XML into the response body. +// It also sets the Content-Type as "application/xml". +func (c *Context) XML(data interface{}, err error) { + code := http.StatusOK + c.Error = err + bcode := ecode.Cause(err) + // TODO app allow 5xx? + /* + if bcode.Code() == -500 { + code = http.StatusServiceUnavailable + } + */ + writeStatusCode(c.Writer, bcode.Code()) + c.Render(code, render.XML{ + Code: bcode.Code(), + Message: bcode.Message(), + Data: data, + }) +} + +// Protobuf serializes the given struct as PB into the response body. +// It also sets the ContentType as "application/x-protobuf". +func (c *Context) Protobuf(data proto.Message, err error) { + var ( + bytes []byte + ) + + code := http.StatusOK + c.Error = err + bcode := ecode.Cause(err) + + any := new(types.Any) + if data != nil { + if bytes, err = proto.Marshal(data); err != nil { + c.Error = errors.WithStack(err) + return + } + any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data) + any.Value = bytes + } + writeStatusCode(c.Writer, bcode.Code()) + c.Render(code, render.PB{ + Code: int64(bcode.Code()), + Message: bcode.Message(), + Data: any, + }) +} + +// Bytes writes some data into the body stream and updates the HTTP code. +func (c *Context) Bytes(code int, contentType string, data ...[]byte) { + c.Render(code, render.Data{ + ContentType: contentType, + Data: data, + }) +} + +// String writes the given string into the response body. +func (c *Context) String(code int, format string, values ...interface{}) { + c.Render(code, render.String{Format: format, Data: values}) +} + +// Redirect returns a HTTP redirect to the specific location. +func (c *Context) Redirect(code int, location string) { + c.Render(-1, render.Redirect{ + Code: code, + Location: location, + Request: c.Request, + }) +} + +// BindWith bind req arg with parser. +func (c *Context) BindWith(obj interface{}, b binding.Binding) error { + return c.mustBindWith(obj, b) +} + +// Bind bind req arg with defult form binding. +func (c *Context) Bind(obj interface{}) error { + return c.mustBindWith(obj, binding.Form) +} + +// mustBindWith binds the passed struct pointer using the specified binding engine. +// It will abort the request with HTTP 400 if any error ocurrs. +// See the binding package. +func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) { + if err = b.Bind(c.Request, obj); err != nil { + c.Error = ecode.RequestErr + c.Render(http.StatusOK, render.JSON{ + Code: ecode.RequestErr.Code(), + Message: err.Error(), + Data: nil, + }) + c.Abort() + } + return +} + +func writeStatusCode(w http.ResponseWriter, ecode int) { + header := w.Header() + header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10)) +} diff --git a/pkg/net/http/blademaster/cors.go b/pkg/net/http/blademaster/cors.go new file mode 100644 index 000000000..f4e88a1db --- /dev/null +++ b/pkg/net/http/blademaster/cors.go @@ -0,0 +1,249 @@ +package blademaster + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/bilibili/Kratos/pkg/log" + + "github.com/pkg/errors" +) + +// CORSConfig represents all available options for the middleware. +type CORSConfig struct { + AllowAllOrigins bool + + // AllowedOrigins is a list of origins a cross-domain request can be executed from. + // If the special "*" value is present in the list, all origins will be allowed. + // Default value is [] + AllowOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It take the origin + // as argument and returns true if allowed or false otherwise. If this option is + // set, the content of AllowedOrigins is ignored. + AllowOriginFunc func(origin string) bool + + // AllowedMethods is a list of methods the client is allowed to use with + // cross-domain requests. Default value is simple methods (GET and POST) + AllowMethods []string + + // AllowedHeaders is list of non simple headers the client is allowed to use with + // cross-domain requests. + AllowHeaders []string + + // AllowCredentials indicates whether the request can include user credentials like + // cookies, HTTP authentication or client side SSL certificates. + AllowCredentials bool + + // ExposedHeaders indicates which headers are safe to expose to the API of a CORS + // API specification + ExposeHeaders []string + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached + MaxAge time.Duration +} + +type cors struct { + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header +} + +type converter func(string) string + +// Validate is check configuration of user defined. +func (c *CORSConfig) Validate() error { + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { + return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") + } + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + return errors.New("conflict settings: all origins disabled") + } + for _, origin := range c.AllowOrigins { + if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { + return errors.New("bad origin: origins must either be '*' or include http:// or https://") + } + } + return nil +} + +// CORS returns the location middleware with default configuration. +func CORS(allowOriginHosts []string) HandlerFunc { + config := &CORSConfig{ + AllowMethods: []string{"GET", "POST"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"}, + AllowCredentials: true, + MaxAge: time.Duration(0), + AllowOriginFunc: func(origin string) bool { + for _, host := range allowOriginHosts { + if strings.HasSuffix(strings.ToLower(origin), host) { + return true + } + } + return false + }, + } + return newCORS(config) +} + +// newCORS returns the location middleware with user-defined custom configuration. +func newCORS(config *CORSConfig) HandlerFunc { + if err := config.Validate(); err != nil { + panic(err.Error()) + } + cors := &cors{ + allowOriginFunc: config.AllowOriginFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + } + + return func(c *Context) { + cors.applyCORS(c) + } +} + +func (cors *cors) applyCORS(c *Context) { + origin := c.Request.Header.Get("Origin") + if len(origin) == 0 { + // request is not a CORS request + return + } + if !cors.validateOrigin(origin) { + log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin) + c.AbortWithStatus(http.StatusForbidden) + return + } + + if c.Request.Method == "OPTIONS" { + cors.handlePreflight(c) + defer c.AbortWithStatus(200) + } else { + cors.handleNormal(c) + } + + if !cors.allowAllOrigins { + header := c.Writer.Header() + header.Set("Access-Control-Allow-Origin", origin) + } +} + +func (cors *cors) validateOrigin(origin string) bool { + if cors.allowAllOrigins { + return true + } + for _, value := range cors.allowOrigins { + if value == origin { + return true + } + } + if cors.allowOriginFunc != nil { + return cors.allowOriginFunc(origin) + } + return false +} + +func (cors *cors) handlePreflight(c *Context) { + header := c.Writer.Header() + for key, value := range cors.preflightHeaders { + header[key] = value + } +} + +func (cors *cors) handleNormal(c *Context) { + header := c.Writer.Header() + for key, value := range cors.normalHeaders { + header[key] = value + } +} + +func generateNormalHeaders(c *CORSConfig) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + + // backport support for early browsers + if len(c.AllowMethods) > 0 { + allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) + value := strings.Join(allowMethods, ",") + headers.Set("Access-Control-Allow-Methods", value) + } + + if len(c.ExposeHeaders) > 0 { + exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) + headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + headers.Set("Vary", "Origin") + } + return headers +} + +func generatePreflightHeaders(c *CORSConfig) http.Header { + headers := make(http.Header) + if c.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + if len(c.AllowMethods) > 0 { + allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) + value := strings.Join(allowMethods, ",") + headers.Set("Access-Control-Allow-Methods", value) + } + if len(c.AllowHeaders) > 0 { + allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) + value := strings.Join(allowHeaders, ",") + headers.Set("Access-Control-Allow-Headers", value) + } + if c.MaxAge > time.Duration(0) { + value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) + headers.Set("Access-Control-Max-Age", value) + } + if c.AllowAllOrigins { + headers.Set("Access-Control-Allow-Origin", "*") + } else { + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + } + return headers +} + +func normalize(values []string) []string { + if values == nil { + return nil + } + distinctMap := make(map[string]bool, len(values)) + normalized := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + value = strings.ToLower(value) + if _, seen := distinctMap[value]; !seen { + normalized = append(normalized, value) + distinctMap[value] = true + } + } + return normalized +} + +func convert(s []string, c converter) []string { + var out []string + for _, i := range s { + out = append(out, c(i)) + } + return out +} diff --git a/pkg/net/http/blademaster/csrf.go b/pkg/net/http/blademaster/csrf.go new file mode 100644 index 000000000..a45825189 --- /dev/null +++ b/pkg/net/http/blademaster/csrf.go @@ -0,0 +1,69 @@ +package blademaster + +import ( + "net/url" + "regexp" + "strings" + + "github.com/bilibili/Kratos/pkg/log" +) + +func matchHostSuffix(suffix string) func(*url.URL) bool { + return func(uri *url.URL) bool { + return strings.HasSuffix(strings.ToLower(uri.Host), suffix) + } +} + +func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool { + return func(uri *url.URL) bool { + return pattern.MatchString(strings.ToLower(uri.String())) + } +} + +// CSRF returns the csrf middleware to prevent invalid cross site request. +// Only referer is checked currently. +func CSRF(allowHosts []string, allowPattern []string) HandlerFunc { + validations := []func(*url.URL) bool{} + + addHostSuffix := func(suffix string) { + validations = append(validations, matchHostSuffix(suffix)) + } + addPattern := func(pattern string) { + validations = append(validations, matchPattern(regexp.MustCompile(pattern))) + } + + for _, r := range allowHosts { + addHostSuffix(r) + } + for _, p := range allowPattern { + addPattern(p) + } + + return func(c *Context) { + referer := c.Request.Header.Get("Referer") + params := c.Request.Form + cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "") + if referer == "" { + if !cross { + return + } + log.V(5).Info("The request's Referer header is empty.") + c.AbortWithStatus(403) + return + } + illegal := true + if uri, err := url.Parse(referer); err == nil && uri.Host != "" { + for _, validate := range validations { + if validate(uri) { + illegal = false + break + } + } + } + if illegal { + log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer) + c.AbortWithStatus(403) + return + } + } +} diff --git a/pkg/net/http/blademaster/logger.go b/pkg/net/http/blademaster/logger.go new file mode 100644 index 000000000..58fb74ff1 --- /dev/null +++ b/pkg/net/http/blademaster/logger.go @@ -0,0 +1,69 @@ +package blademaster + +import ( + "fmt" + "strconv" + "time" + + "github.com/bilibili/Kratos/pkg/ecode" + "github.com/bilibili/Kratos/pkg/log" + "github.com/bilibili/Kratos/pkg/net/metadata" +) + +// Logger is logger middleware +func Logger() HandlerFunc { + const noUser = "no_user" + return func(c *Context) { + now := time.Now() + ip := metadata.String(c, metadata.RemoteIP) + req := c.Request + path := req.URL.Path + params := req.Form + var quota float64 + if deadline, ok := c.Context.Deadline(); ok { + quota = time.Until(deadline).Seconds() + } + + c.Next() + + err := c.Error + cerr := ecode.Cause(err) + dt := time.Since(now) + caller := metadata.String(c, metadata.Caller) + if caller == "" { + caller = noUser + } + + stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10)) + stats.Timing(caller, int64(dt/time.Millisecond), path[1:]) + + lf := log.Infov + errmsg := "" + isSlow := dt >= (time.Millisecond * 500) + if err != nil { + errmsg = err.Error() + lf = log.Errorv + if cerr.Code() > 0 { + lf = log.Warnv + } + } else { + if isSlow { + lf = log.Warnv + } + } + lf(c, + log.KVString("method", req.Method), + log.KVString("ip", ip), + log.KVString("user", caller), + log.KVString("path", path), + log.KVString("params", params.Encode()), + log.KVInt("ret", cerr.Code()), + log.KVString("msg", cerr.Message()), + log.KVString("stack", fmt.Sprintf("%+v", err)), + log.KVString("err", errmsg), + log.KVFloat64("timeout_quota", quota), + log.KVFloat64("ts", dt.Seconds()), + log.KVString("source", "http-access-log"), + ) + } +} diff --git a/pkg/net/http/blademaster/metadata.go b/pkg/net/http/blademaster/metadata.go index a868f4305..443e86f24 100644 --- a/pkg/net/http/blademaster/metadata.go +++ b/pkg/net/http/blademaster/metadata.go @@ -17,13 +17,14 @@ const ( _httpHeaderUser = "x1-bmspy-user" _httpHeaderColor = "x1-bmspy-color" _httpHeaderTimeout = "x1-bmspy-timeout" + _httpHeaderMirror = "x1-bmspy-mirror" _httpHeaderRemoteIP = "x-backend-bm-real-ip" _httpHeaderRemoteIPPort = "x-backend-bm-real-ipport" ) -// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true. +// mirror return true if x-bmspy-mirror in http header and its value is 1 or true. func mirror(req *http.Request) bool { - mirrorStr := req.Header.Get("x1-bilispy-mirror") + mirrorStr := req.Header.Get(_httpHeaderMirror) if mirrorStr == "" { return false } @@ -79,7 +80,7 @@ func timeout(req *http.Request) time.Duration { } // remoteIP implements a best effort algorithm to return the real client IP, it parses -// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. +// x-backend-bm-real-ip or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. func remoteIP(req *http.Request) (remote string) { if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" { diff --git a/pkg/net/http/blademaster/perf.go b/pkg/net/http/blademaster/perf.go new file mode 100644 index 000000000..361cee607 --- /dev/null +++ b/pkg/net/http/blademaster/perf.go @@ -0,0 +1,46 @@ +package blademaster + +import ( + "flag" + "net/http" + "net/http/pprof" + "os" + "sync" + + "github.com/bilibili/Kratos/pkg/conf/dsn" + + "github.com/pkg/errors" +) + +var ( + _perfOnce sync.Once + _perfDSN string +) + +func init() { + v := os.Getenv("HTTP_PERF") + if v == "" { + v = "tcp://0.0.0.0:2333" + } + flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.") +} + +func startPerf() { + _perfOnce.Do(func() { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + + go func() { + d, err := dsn.Parse(_perfDSN) + if err != nil { + panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err)) + } + if err := http.ListenAndServe(d.Host, mux); err != nil { + panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err)) + } + }() + }) +} diff --git a/pkg/net/http/blademaster/prometheus.go b/pkg/net/http/blademaster/prometheus.go new file mode 100644 index 000000000..68af2ee58 --- /dev/null +++ b/pkg/net/http/blademaster/prometheus.go @@ -0,0 +1,12 @@ +package blademaster + +import ( + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +func monitor() HandlerFunc { + return func(c *Context) { + h := promhttp.Handler() + h.ServeHTTP(c.Writer, c.Request) + } +} diff --git a/pkg/net/http/blademaster/recovery.go b/pkg/net/http/blademaster/recovery.go new file mode 100644 index 000000000..403d8a5d9 --- /dev/null +++ b/pkg/net/http/blademaster/recovery.go @@ -0,0 +1,32 @@ +package blademaster + +import ( + "fmt" + "net/http/httputil" + "os" + "runtime" + + "github.com/bilibili/Kratos/pkg/log" +) + +// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. +func Recovery() HandlerFunc { + return func(c *Context) { + defer func() { + var rawReq []byte + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + if c.Request != nil { + rawReq, _ = httputil.DumpRequest(c.Request, false) + } + pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf) + fmt.Fprintf(os.Stderr, pl) + log.Error(pl) + c.AbortWithStatus(500) + } + }() + c.Next() + } +} diff --git a/pkg/net/http/blademaster/render/data.go b/pkg/net/http/blademaster/render/data.go new file mode 100644 index 000000000..d602350b0 --- /dev/null +++ b/pkg/net/http/blademaster/render/data.go @@ -0,0 +1,30 @@ +package render + +import ( + "net/http" + + "github.com/pkg/errors" +) + +// Data common bytes struct. +type Data struct { + ContentType string + Data [][]byte +} + +// Render (Data) writes data with custom ContentType. +func (r Data) Render(w http.ResponseWriter) (err error) { + r.WriteContentType(w) + for _, d := range r.Data { + if _, err = w.Write(d); err != nil { + err = errors.WithStack(err) + return + } + } + return +} + +// WriteContentType writes data with custom ContentType. +func (r Data) WriteContentType(w http.ResponseWriter) { + writeContentType(w, []string{r.ContentType}) +} diff --git a/pkg/net/http/blademaster/render/json.go b/pkg/net/http/blademaster/render/json.go new file mode 100644 index 000000000..5a5f23bff --- /dev/null +++ b/pkg/net/http/blademaster/render/json.go @@ -0,0 +1,58 @@ +package render + +import ( + "encoding/json" + "net/http" + + "github.com/pkg/errors" +) + +var jsonContentType = []string{"application/json; charset=utf-8"} + +// JSON common json struct. +type JSON struct { + Code int `json:"code"` + Message string `json:"message"` + TTL int `json:"ttl"` + Data interface{} `json:"data,omitempty"` +} + +func writeJSON(w http.ResponseWriter, obj interface{}) (err error) { + var jsonBytes []byte + writeContentType(w, jsonContentType) + if jsonBytes, err = json.Marshal(obj); err != nil { + err = errors.WithStack(err) + return + } + if _, err = w.Write(jsonBytes); err != nil { + err = errors.WithStack(err) + } + return +} + +// Render (JSON) writes data with json ContentType. +func (r JSON) Render(w http.ResponseWriter) error { + // FIXME(zhoujiahui): the TTL field will be configurable in the future + if r.TTL <= 0 { + r.TTL = 1 + } + return writeJSON(w, r) +} + +// WriteContentType write json ContentType. +func (r JSON) WriteContentType(w http.ResponseWriter) { + writeContentType(w, jsonContentType) +} + +// MapJSON common map json struct. +type MapJSON map[string]interface{} + +// Render (MapJSON) writes data with json ContentType. +func (m MapJSON) Render(w http.ResponseWriter) error { + return writeJSON(w, m) +} + +// WriteContentType write json ContentType. +func (m MapJSON) WriteContentType(w http.ResponseWriter) { + writeContentType(w, jsonContentType) +} diff --git a/pkg/net/http/blademaster/render/protobuf.go b/pkg/net/http/blademaster/render/protobuf.go new file mode 100644 index 000000000..4664f2b5f --- /dev/null +++ b/pkg/net/http/blademaster/render/protobuf.go @@ -0,0 +1,38 @@ +package render + +import ( + "net/http" + + "github.com/gogo/protobuf/proto" + "github.com/pkg/errors" +) + +var pbContentType = []string{"application/x-protobuf"} + +// Render (PB) writes data with protobuf ContentType. +func (r PB) Render(w http.ResponseWriter) error { + if r.TTL <= 0 { + r.TTL = 1 + } + return writePB(w, r) +} + +// WriteContentType write protobuf ContentType. +func (r PB) WriteContentType(w http.ResponseWriter) { + writeContentType(w, pbContentType) +} + +func writePB(w http.ResponseWriter, obj PB) (err error) { + var pbBytes []byte + writeContentType(w, pbContentType) + + if pbBytes, err = proto.Marshal(&obj); err != nil { + err = errors.WithStack(err) + return + } + + if _, err = w.Write(pbBytes); err != nil { + err = errors.WithStack(err) + } + return +} diff --git a/pkg/net/http/blademaster/render/redirect.go b/pkg/net/http/blademaster/render/redirect.go new file mode 100644 index 000000000..73e516d65 --- /dev/null +++ b/pkg/net/http/blademaster/render/redirect.go @@ -0,0 +1,26 @@ +package render + +import ( + "net/http" + + "github.com/pkg/errors" +) + +// Redirect render for redirect to specified location. +type Redirect struct { + Code int + Request *http.Request + Location string +} + +// Render (Redirect) redirect to specified location. +func (r Redirect) Render(w http.ResponseWriter) error { + if (r.Code < 300 || r.Code > 308) && r.Code != 201 { + return errors.Errorf("Cannot redirect with status code %d", r.Code) + } + http.Redirect(w, r.Request, r.Location, r.Code) + return nil +} + +// WriteContentType noneContentType. +func (r Redirect) WriteContentType(http.ResponseWriter) {} diff --git a/pkg/net/http/blademaster/render/render.go b/pkg/net/http/blademaster/render/render.go new file mode 100644 index 000000000..13188637e --- /dev/null +++ b/pkg/net/http/blademaster/render/render.go @@ -0,0 +1,30 @@ +package render + +import ( + "net/http" +) + +// Render http reponse render. +type Render interface { + // Render render it to http response writer. + Render(http.ResponseWriter) error + // WriteContentType write content-type to http response writer. + WriteContentType(w http.ResponseWriter) +} + +var ( + _ Render = JSON{} + _ Render = MapJSON{} + _ Render = XML{} + _ Render = String{} + _ Render = Redirect{} + _ Render = Data{} + _ Render = PB{} +) + +func writeContentType(w http.ResponseWriter, value []string) { + header := w.Header() + if val := header["Content-Type"]; len(val) == 0 { + header["Content-Type"] = value + } +} diff --git a/pkg/net/http/blademaster/render/render.pb.go b/pkg/net/http/blademaster/render/render.pb.go new file mode 100644 index 000000000..bb5390e98 --- /dev/null +++ b/pkg/net/http/blademaster/render/render.pb.go @@ -0,0 +1,89 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: pb.proto + +/* +Package render is a generated protocol buffer package. + +It is generated from these files: + pb.proto + +It has these top-level messages: + PB +*/ +package render + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/gogo/protobuf/types" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type PB struct { + Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` + Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` + TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"` + Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"` +} + +func (m *PB) Reset() { *m = PB{} } +func (m *PB) String() string { return proto.CompactTextString(m) } +func (*PB) ProtoMessage() {} +func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} } + +func (m *PB) GetCode() int64 { + if m != nil { + return m.Code + } + return 0 +} + +func (m *PB) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + +func (m *PB) GetTTL() uint64 { + if m != nil { + return m.TTL + } + return 0 +} + +func (m *PB) GetData() *google_protobuf.Any { + if m != nil { + return m.Data + } + return nil +} + +func init() { + proto.RegisterType((*PB)(nil), "render.PB") +} + +func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) } + +var fileDescriptorPb = []byte{ + // 154 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b, + 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf, + 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94, + 0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30, + 0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, + 0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48, + 0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92, + 0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66, + 0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00, + 0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00, +} diff --git a/pkg/net/http/blademaster/render/render.proto b/pkg/net/http/blademaster/render/render.proto new file mode 100644 index 000000000..e3f53870f --- /dev/null +++ b/pkg/net/http/blademaster/render/render.proto @@ -0,0 +1,14 @@ +// use under command to generate pb.pb.go +// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto +syntax = "proto3"; +package render; + +import "google/protobuf/any.proto"; +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +message PB { + int64 Code = 1; + string Message = 2; + uint64 TTL = 3; + google.protobuf.Any Data = 4; +} \ No newline at end of file diff --git a/pkg/net/http/blademaster/render/string.go b/pkg/net/http/blademaster/render/string.go new file mode 100644 index 000000000..4112b5713 --- /dev/null +++ b/pkg/net/http/blademaster/render/string.go @@ -0,0 +1,40 @@ +package render + +import ( + "fmt" + "io" + "net/http" + + "github.com/pkg/errors" +) + +var plainContentType = []string{"text/plain; charset=utf-8"} + +// String common string struct. +type String struct { + Format string + Data []interface{} +} + +// Render (String) writes data with custom ContentType. +func (r String) Render(w http.ResponseWriter) error { + return writeString(w, r.Format, r.Data) +} + +// WriteContentType writes string with text/plain ContentType. +func (r String) WriteContentType(w http.ResponseWriter) { + writeContentType(w, plainContentType) +} + +func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) { + writeContentType(w, plainContentType) + if len(data) > 0 { + _, err = fmt.Fprintf(w, format, data...) + } else { + _, err = io.WriteString(w, format) + } + if err != nil { + err = errors.WithStack(err) + } + return +} diff --git a/pkg/net/http/blademaster/render/xml.go b/pkg/net/http/blademaster/render/xml.go new file mode 100644 index 000000000..8837c582c --- /dev/null +++ b/pkg/net/http/blademaster/render/xml.go @@ -0,0 +1,31 @@ +package render + +import ( + "encoding/xml" + "net/http" + + "github.com/pkg/errors" +) + +// XML common xml struct. +type XML struct { + Code int + Message string + Data interface{} +} + +var xmlContentType = []string{"application/xml; charset=utf-8"} + +// Render (XML) writes data with xml ContentType. +func (r XML) Render(w http.ResponseWriter) (err error) { + r.WriteContentType(w) + if err = xml.NewEncoder(w).Encode(r.Data); err != nil { + err = errors.WithStack(err) + } + return +} + +// WriteContentType write xml ContentType. +func (r XML) WriteContentType(w http.ResponseWriter) { + writeContentType(w, xmlContentType) +} diff --git a/pkg/net/http/blademaster/routergroup.go b/pkg/net/http/blademaster/routergroup.go new file mode 100644 index 000000000..28d09a805 --- /dev/null +++ b/pkg/net/http/blademaster/routergroup.go @@ -0,0 +1,166 @@ +package blademaster + +import ( + "regexp" +) + +// IRouter http router framework interface. +type IRouter interface { + IRoutes + Group(string, ...HandlerFunc) *RouterGroup +} + +// IRoutes http router interface. +type IRoutes interface { + UseFunc(...HandlerFunc) IRoutes + Use(...Handler) IRoutes + + Handle(string, string, ...HandlerFunc) IRoutes + HEAD(string, ...HandlerFunc) IRoutes + GET(string, ...HandlerFunc) IRoutes + POST(string, ...HandlerFunc) IRoutes + PUT(string, ...HandlerFunc) IRoutes + DELETE(string, ...HandlerFunc) IRoutes +} + +// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix +// and an array of handlers (middleware). +type RouterGroup struct { + Handlers []HandlerFunc + basePath string + engine *Engine + root bool + baseConfig *MethodConfig +} + +var _ IRouter = &RouterGroup{} + +// Use adds middleware to the group, see example code in doc. +func (group *RouterGroup) Use(middleware ...Handler) IRoutes { + for _, m := range middleware { + group.Handlers = append(group.Handlers, m.ServeHTTP) + } + return group.returnObj() +} + +// UseFunc adds middleware to the group, see example code in doc. +func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes { + group.Handlers = append(group.Handlers, middleware...) + return group.returnObj() +} + +// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix. +// For example, all the routes that use a common middlware for authorization could be grouped. +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { + return &RouterGroup{ + Handlers: group.combineHandlers(handlers), + basePath: group.calculateAbsolutePath(relativePath), + engine: group.engine, + root: false, + } +} + +// SetMethodConfig is used to set config on specified method +func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup { + group.baseConfig = config + return group +} + +// BasePath router group base path. +func (group *RouterGroup) BasePath() string { + return group.basePath +} + +func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { + absolutePath := group.calculateAbsolutePath(relativePath) + injections := group.injections(relativePath) + handlers = group.combineHandlers(injections, handlers) + group.engine.addRoute(httpMethod, absolutePath, handlers...) + if group.baseConfig != nil { + group.engine.SetMethodConfig(absolutePath, group.baseConfig) + } + return group.returnObj() +} + +// Handle registers a new request handle and middleware with the given path and method. +// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes. +// See the example code in doc. +// +// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut +// functions can be used. +// +// This function is intended for bulk loading and to allow the usage of less +// frequently used, non-standardized or custom methods (e.g. for internal +// communication with a proxy). +func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { + if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil { + panic("http method " + httpMethod + " is not valid") + } + return group.handle(httpMethod, relativePath, handlers...) +} + +// HEAD is a shortcut for router.Handle("HEAD", path, handle). +func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle("HEAD", relativePath, handlers...) +} + +// GET is a shortcut for router.Handle("GET", path, handle). +func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle("GET", relativePath, handlers...) +} + +// POST is a shortcut for router.Handle("POST", path, handle). +func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle("POST", relativePath, handlers...) +} + +// PUT is a shortcut for router.Handle("PUT", path, handle). +func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle("PUT", relativePath, handlers...) +} + +// DELETE is a shortcut for router.Handle("DELETE", path, handle). +func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle("DELETE", relativePath, handlers...) +} + +func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc { + finalSize := len(group.Handlers) + for _, handlers := range handlerGroups { + finalSize += len(handlers) + } + if finalSize >= int(_abortIndex) { + panic("too many handlers") + } + mergedHandlers := make([]HandlerFunc, finalSize) + copy(mergedHandlers, group.Handlers) + position := len(group.Handlers) + for _, handlers := range handlerGroups { + copy(mergedHandlers[position:], handlers) + position += len(handlers) + } + return mergedHandlers +} + +func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { + return joinPaths(group.basePath, relativePath) +} + +func (group *RouterGroup) returnObj() IRoutes { + if group.root { + return group.engine + } + return group +} + +// injections is +func (group *RouterGroup) injections(relativePath string) []HandlerFunc { + absPath := group.calculateAbsolutePath(relativePath) + for _, injection := range group.engine.injections { + if !injection.pattern.MatchString(absPath) { + continue + } + return injection.handlers + } + return nil +} diff --git a/pkg/net/http/blademaster/server.go b/pkg/net/http/blademaster/server.go new file mode 100644 index 000000000..c96513ba3 --- /dev/null +++ b/pkg/net/http/blademaster/server.go @@ -0,0 +1,445 @@ +package blademaster + +import ( + "context" + "flag" + "fmt" + "net" + "net/http" + "os" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bilibili/Kratos/pkg/conf/dsn" + "github.com/bilibili/Kratos/pkg/log" + "github.com/bilibili/Kratos/pkg/net/ip" + "github.com/bilibili/Kratos/pkg/net/metadata" + "github.com/bilibili/Kratos/pkg/stat" + xtime "github.com/bilibili/Kratos/pkg/time" + + "github.com/pkg/errors" +) + +const ( + defaultMaxMemory = 32 << 20 // 32 MB +) + +var ( + _ IRouter = &Engine{} + stats = stat.HTTPServer + + _httpDSN string +) + +func init() { + addFlag(flag.CommandLine) +} + +func addFlag(fs *flag.FlagSet) { + v := os.Getenv("HTTP") + if v == "" { + v = "tcp://0.0.0.0:8000/?timeout=1s" + } + fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.") +} + +func parseDSN(rawdsn string) *ServerConfig { + conf := new(ServerConfig) + d, err := dsn.Parse(rawdsn) + if err != nil { + panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) + } + if _, err = d.Bind(conf); err != nil { + panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) + } + return conf +} + +// Handler responds to an HTTP request. +type Handler interface { + ServeHTTP(c *Context) +} + +// HandlerFunc http request handler function. +type HandlerFunc func(*Context) + +// ServeHTTP calls f(ctx). +func (f HandlerFunc) ServeHTTP(c *Context) { + f(c) +} + +// ServerConfig is the bm server config model +type ServerConfig struct { + Network string `dsn:"network"` + // FIXME: rename to Address + Addr string `dsn:"address"` + Timeout xtime.Duration `dsn:"query.timeout"` + ReadTimeout xtime.Duration `dsn:"query.readTimeout"` + WriteTimeout xtime.Duration `dsn:"query.writeTimeout"` +} + +// MethodConfig is +type MethodConfig struct { + Timeout xtime.Duration +} + +// Start listen and serve bm engine by given DSN. +func (engine *Engine) Start() error { + conf := engine.conf + l, err := net.Listen(conf.Network, conf.Addr) + if err != nil { + errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr) + return err + } + + log.Info("blademaster: start http listen addr: %s", conf.Addr) + server := &http.Server{ + ReadTimeout: time.Duration(conf.ReadTimeout), + WriteTimeout: time.Duration(conf.WriteTimeout), + } + go func() { + if err := engine.RunServer(server, l); err != nil { + if errors.Cause(err) == http.ErrServerClosed { + log.Info("blademaster: server closed") + return + } + panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l)) + } + }() + + return nil +} + +// Engine is the framework's instance, it contains the muxer, middleware and configuration settings. +// Create an instance of Engine, by using New() or Default() +type Engine struct { + RouterGroup + + lock sync.RWMutex + conf *ServerConfig + + address string + + mux *http.ServeMux // http mux router + server atomic.Value // store *http.Server + metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata + + pcLock sync.RWMutex + methodConfigs map[string]*MethodConfig + + injections []injection +} + +type injection struct { + pattern *regexp.Regexp + handlers []HandlerFunc +} + +// New returns a new blank Engine instance without any middleware attached. +// +// Deprecated: please use NewServer. +func New() *Engine { + engine := &Engine{ + RouterGroup: RouterGroup{ + Handlers: nil, + basePath: "/", + root: true, + }, + address: ip.InternalIP(), + conf: &ServerConfig{ + Timeout: xtime.Duration(time.Second), + }, + mux: http.NewServeMux(), + metastore: make(map[string]map[string]interface{}), + methodConfigs: make(map[string]*MethodConfig), + injections: make([]injection, 0), + } + engine.RouterGroup.engine = engine + // NOTE add prometheus monitor location + engine.addRoute("GET", "/metrics", monitor()) + engine.addRoute("GET", "/metadata", engine.metadata()) + startPerf() + return engine +} + +// NewServer returns a new blank Engine instance without any middleware attached. +func NewServer(conf *ServerConfig) *Engine { + if conf == nil { + if !flag.Parsed() { + fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n") + } + conf = parseDSN(_httpDSN) + } else { + fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n") + } + + engine := &Engine{ + RouterGroup: RouterGroup{ + Handlers: nil, + basePath: "/", + root: true, + }, + address: ip.InternalIP(), + mux: http.NewServeMux(), + metastore: make(map[string]map[string]interface{}), + methodConfigs: make(map[string]*MethodConfig), + } + if err := engine.SetConfig(conf); err != nil { + panic(err) + } + engine.RouterGroup.engine = engine + // NOTE add prometheus monitor location + engine.addRoute("GET", "/metrics", monitor()) + engine.addRoute("GET", "/metadata", engine.metadata()) + startPerf() + return engine +} + +// SetMethodConfig is used to set config on specified path +func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) { + engine.pcLock.Lock() + engine.methodConfigs[path] = mc + engine.pcLock.Unlock() +} + +// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached. +func DefaultServer(conf *ServerConfig) *Engine { + engine := NewServer(conf) + engine.Use(Recovery(), Trace(), Logger()) + return engine +} + +// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached. +// +// Deprecated: please use DefaultServer. +func Default() *Engine { + engine := New() + engine.Use(Recovery(), Trace(), Logger()) + return engine +} + +func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) { + if path[0] != '/' { + panic("blademaster: path must begin with '/'") + } + if method == "" { + panic("blademaster: HTTP method can not be empty") + } + if len(handlers) == 0 { + panic("blademaster: there must be at least one handler") + } + if _, ok := engine.metastore[path]; !ok { + engine.metastore[path] = make(map[string]interface{}) + } + engine.metastore[path]["method"] = method + engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { + c := &Context{ + Context: nil, + engine: engine, + index: -1, + handlers: nil, + Keys: nil, + method: "", + Error: nil, + } + + c.Request = req + c.Writer = w + c.handlers = handlers + c.method = method + + engine.handleContext(c) + }) +} + +// SetConfig is used to set the engine configuration. +// Only the valid config will be loaded. +func (engine *Engine) SetConfig(conf *ServerConfig) (err error) { + if conf.Timeout <= 0 { + return errors.New("blademaster: config timeout must greater than 0") + } + if conf.Network == "" { + conf.Network = "tcp" + } + engine.lock.Lock() + engine.conf = conf + engine.lock.Unlock() + return +} + +func (engine *Engine) methodConfig(path string) *MethodConfig { + engine.pcLock.RLock() + mc := engine.methodConfigs[path] + engine.pcLock.RUnlock() + return mc +} + +func (engine *Engine) handleContext(c *Context) { + var cancel func() + req := c.Request + ctype := req.Header.Get("Content-Type") + switch { + case strings.Contains(ctype, "multipart/form-data"): + req.ParseMultipartForm(defaultMaxMemory) + default: + req.ParseForm() + } + // get derived timeout from http request header, + // compare with the engine configured, + // and use the minimum one + engine.lock.RLock() + tm := time.Duration(engine.conf.Timeout) + engine.lock.RUnlock() + // the method config is preferred + if pc := engine.methodConfig(c.Request.URL.Path); pc != nil { + tm = time.Duration(pc.Timeout) + } + if ctm := timeout(req); ctm > 0 && tm > ctm { + tm = ctm + } + md := metadata.MD{ + metadata.Color: color(req), + metadata.RemoteIP: remoteIP(req), + metadata.RemotePort: remotePort(req), + metadata.Caller: caller(req), + metadata.Mirror: mirror(req), + } + ctx := metadata.NewContext(context.Background(), md) + if tm > 0 { + c.Context, cancel = context.WithTimeout(ctx, tm) + } else { + c.Context, cancel = context.WithCancel(ctx) + } + defer cancel() + c.Next() +} + +// Router return a http.Handler for using http.ListenAndServe() directly. +func (engine *Engine) Router() http.Handler { + return engine.mux +} + +// Server is used to load stored http server. +func (engine *Engine) Server() *http.Server { + s, ok := engine.server.Load().(*http.Server) + if !ok { + return nil + } + return s +} + +// Shutdown the http server without interrupting active connections. +func (engine *Engine) Shutdown(ctx context.Context) error { + server := engine.Server() + if server == nil { + return errors.New("blademaster: no server") + } + return errors.WithStack(server.Shutdown(ctx)) +} + +// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be +// included in the handlers chain for every single request. Even 404, 405, static files... +// For example, this is the right place for a logger or error management middleware. +func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes { + engine.RouterGroup.UseFunc(middleware...) + return engine +} + +// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be +// included in the handlers chain for every single request. Even 404, 405, static files... +// For example, this is the right place for a logger or error management middleware. +func (engine *Engine) Use(middleware ...Handler) IRoutes { + engine.RouterGroup.Use(middleware...) + return engine +} + +// Ping is used to set the general HTTP ping handler. +func (engine *Engine) Ping(handler HandlerFunc) { + engine.GET("/monitor/ping", handler) +} + +// Register is used to export metadata to discovery. +func (engine *Engine) Register(handler HandlerFunc) { + engine.GET("/register", handler) +} + +// Run attaches the router to a http.Server and starts listening and serving HTTP requests. +// It is a shortcut for http.ListenAndServe(addr, router) +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) Run(addr ...string) (err error) { + address := resolveAddress(addr) + server := &http.Server{ + Addr: address, + Handler: engine.mux, + } + engine.server.Store(server) + if err = server.ListenAndServe(); err != nil { + err = errors.Wrapf(err, "addrs: %v", addr) + } + return +} + +// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests. +// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router) +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { + server := &http.Server{ + Addr: addr, + Handler: engine.mux, + } + engine.server.Store(server) + if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { + err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile) + } + return +} + +// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests +// through the specified unix socket (ie. a file). +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) RunUnix(file string) (err error) { + os.Remove(file) + listener, err := net.Listen("unix", file) + if err != nil { + err = errors.Wrapf(err, "unix: %s", file) + return + } + defer listener.Close() + server := &http.Server{ + Handler: engine.mux, + } + engine.server.Store(server) + if err = server.Serve(listener); err != nil { + err = errors.Wrapf(err, "unix: %s", file) + } + return +} + +// RunServer will serve and start listening HTTP requests by given server and listener. +// Note: this method will block the calling goroutine indefinitely unless an error happens. +func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) { + server.Handler = engine.mux + engine.server.Store(server) + if err = server.Serve(l); err != nil { + err = errors.Wrapf(err, "listen server: %+v/%+v", server, l) + return + } + return +} + +func (engine *Engine) metadata() HandlerFunc { + return func(c *Context) { + c.JSON(engine.metastore, nil) + } +} + +// Inject is +func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) { + engine.injections = append(engine.injections, injection{ + pattern: regexp.MustCompile(pattern), + handlers: handlers, + }) +} diff --git a/pkg/net/http/blademaster/trace.go b/pkg/net/http/blademaster/trace.go index b792246bb..f607f18ca 100644 --- a/pkg/net/http/blademaster/trace.go +++ b/pkg/net/http/blademaster/trace.go @@ -4,12 +4,43 @@ import ( "io" "net/http" "net/http/httptrace" + "strconv" + "github.com/bilibili/Kratos/pkg/net/metadata" "github.com/bilibili/Kratos/pkg/net/trace" ) const _defaultComponentName = "net/http" +// Trace is trace middleware +func Trace() HandlerFunc { + return func(c *Context) { + // handle http request + // get derived trace from http request header + t, err := trace.Extract(trace.HTTPFormat, c.Request.Header) + if err != nil { + var opts []trace.Option + if ok, _ := strconv.ParseBool(trace.KratosTraceDebug); ok { + opts = append(opts, trace.EnableDebug()) + } + t = trace.New(c.Request.URL.Path, opts...) + } + t.SetTitle(c.Request.URL.Path) + t.SetTag(trace.String(trace.TagComponent, _defaultComponentName)) + t.SetTag(trace.String(trace.TagHTTPMethod, c.Request.Method)) + t.SetTag(trace.String(trace.TagHTTPURL, c.Request.URL.String())) + t.SetTag(trace.String(trace.TagSpanKind, "server")) + // business tag + t.SetTag(trace.String("caller", metadata.String(c.Context, metadata.Caller))) + // export trace id to user. + // TODO(zhoujiahui): trace package should be updated + // c.Writer.Header().Set(trace.KratosTraceID, t.TraceID()) + c.Context = trace.NewContext(c.Context, t) + c.Next() + t.Finish(&c.Error) + } +} + type closeTracker struct { io.ReadCloser tr trace.Trace diff --git a/pkg/net/http/blademaster/utils.go b/pkg/net/http/blademaster/utils.go new file mode 100644 index 000000000..54da96d3b --- /dev/null +++ b/pkg/net/http/blademaster/utils.go @@ -0,0 +1,42 @@ +package blademaster + +import ( + "os" + "path" +) + +func lastChar(str string) uint8 { + if str == "" { + panic("The length of the string can't be 0") + } + return str[len(str)-1] +} + +func joinPaths(absolutePath, relativePath string) string { + if relativePath == "" { + return absolutePath + } + + finalPath := path.Join(absolutePath, relativePath) + appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' + if appendSlash { + return finalPath + "/" + } + return finalPath +} + +func resolveAddress(addr []string) string { + switch len(addr) { + case 0: + if port := os.Getenv("PORT"); port != "" { + //debugPrint("Environment variable PORT=\"%s\"", port) + return ":" + port + } + //debugPrint("Environment variable PORT is undefined. Using port :8080 by default") + return ":8080" + case 1: + return addr[0] + default: + panic("too much parameters") + } +}