diff --git a/pkg/net/http/blademaster/binding/binding.go b/pkg/net/http/blademaster/binding/binding.go
new file mode 100644
index 000000000..46f1c770d
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/binding.go
@@ -0,0 +1,85 @@
+package binding
+
+import (
+ "net/http"
+ "strings"
+
+ "gopkg.in/go-playground/validator.v9"
+)
+
+// MIME
+const (
+ MIMEJSON = "application/json"
+ MIMEHTML = "text/html"
+ MIMEXML = "application/xml"
+ MIMEXML2 = "text/xml"
+ MIMEPlain = "text/plain"
+ MIMEPOSTForm = "application/x-www-form-urlencoded"
+ MIMEMultipartPOSTForm = "multipart/form-data"
+)
+
+// Binding http binding request interface.
+type Binding interface {
+ Name() string
+ Bind(*http.Request, interface{}) error
+}
+
+// StructValidator http validator interface.
+type StructValidator interface {
+ // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
+ // If the received type is not a struct, any validation should be skipped and nil must be returned.
+ // If the received type is a struct or pointer to a struct, the validation should be performed.
+ // If the struct is not valid or the validation itself fails, a descriptive error should be returned.
+ // Otherwise nil must be returned.
+ ValidateStruct(interface{}) error
+
+ // RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
+ // NOTE: if the key already exists, the previous validation function will be replaced.
+ // NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
+ RegisterValidation(string, validator.Func) error
+}
+
+// Validator default validator.
+var Validator StructValidator = &defaultValidator{}
+
+// Binding
+var (
+ JSON = jsonBinding{}
+ XML = xmlBinding{}
+ Form = formBinding{}
+ Query = queryBinding{}
+ FormPost = formPostBinding{}
+ FormMultipart = formMultipartBinding{}
+)
+
+// Default get by binding type by method and contexttype.
+func Default(method, contentType string) Binding {
+ if method == "GET" {
+ return Form
+ }
+
+ contentType = stripContentTypeParam(contentType)
+ switch contentType {
+ case MIMEJSON:
+ return JSON
+ case MIMEXML, MIMEXML2:
+ return XML
+ default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
+ return Form
+ }
+}
+
+func validate(obj interface{}) error {
+ if Validator == nil {
+ return nil
+ }
+ return Validator.ValidateStruct(obj)
+}
+
+func stripContentTypeParam(contentType string) string {
+ i := strings.Index(contentType, ";")
+ if i != -1 {
+ contentType = contentType[:i]
+ }
+ return contentType
+}
diff --git a/pkg/net/http/blademaster/binding/binding_test.go b/pkg/net/http/blademaster/binding/binding_test.go
new file mode 100644
index 000000000..8a5da8a6e
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/binding_test.go
@@ -0,0 +1,342 @@
+package binding
+
+import (
+ "bytes"
+ "mime/multipart"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type FooStruct struct {
+ Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"`
+}
+
+type FooBarStruct struct {
+ FooStruct
+ Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"`
+ Slice []string `form:"slice" validate:"max=10"`
+}
+
+type ComplexDefaultStruct struct {
+ Int int `form:"int" default:"999"`
+ String string `form:"string" default:"default-string"`
+ Bool bool `form:"bool" default:"false"`
+ Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"`
+ Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"`
+}
+
+type Int8SliceStruct struct {
+ State []int8 `form:"state,split"`
+}
+
+type Int64SliceStruct struct {
+ State []int64 `form:"state,split"`
+}
+
+type StringSliceStruct struct {
+ State []string `form:"state,split"`
+}
+
+func TestBindingDefault(t *testing.T) {
+ assert.Equal(t, Default("GET", ""), Form)
+ assert.Equal(t, Default("GET", MIMEJSON), Form)
+ assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form)
+
+ assert.Equal(t, Default("POST", MIMEJSON), JSON)
+ assert.Equal(t, Default("PUT", MIMEJSON), JSON)
+
+ assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON)
+ assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON)
+
+ assert.Equal(t, Default("POST", MIMEXML), XML)
+ assert.Equal(t, Default("PUT", MIMEXML2), XML)
+
+ assert.Equal(t, Default("POST", MIMEPOSTForm), Form)
+ assert.Equal(t, Default("PUT", MIMEPOSTForm), Form)
+
+ assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form)
+ assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form)
+
+ assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form)
+ assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form)
+
+}
+
+func TestStripContentType(t *testing.T) {
+ c1 := "application/vnd.mozilla.xul+xml"
+ c2 := "application/vnd.mozilla.xul+xml; charset=utf-8"
+ assert.Equal(t, stripContentTypeParam(c1), c1)
+ assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml")
+}
+
+func TestBindInt8Form(t *testing.T) {
+ params := "state=1,2,3"
+ req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q := new(Int8SliceStruct)
+ Form.Bind(req, q)
+ assert.EqualValues(t, []int8{1, 2, 3}, q.State)
+
+ params = "state=1,2,3,256"
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(Int8SliceStruct)
+ assert.Error(t, Form.Bind(req, q))
+
+ params = "state="
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(Int8SliceStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Len(t, q.State, 0)
+
+ params = "state=1,,2"
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(Int8SliceStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.EqualValues(t, []int8{1, 2}, q.State)
+}
+
+func TestBindInt64Form(t *testing.T) {
+ params := "state=1,2,3"
+ req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q := new(Int64SliceStruct)
+ Form.Bind(req, q)
+ assert.EqualValues(t, []int64{1, 2, 3}, q.State)
+
+ params = "state="
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(Int64SliceStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Len(t, q.State, 0)
+}
+
+func TestBindStringForm(t *testing.T) {
+ params := "state=1,2,3"
+ req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q := new(StringSliceStruct)
+ Form.Bind(req, q)
+ assert.EqualValues(t, []string{"1", "2", "3"}, q.State)
+
+ params = "state="
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(StringSliceStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Len(t, q.State, 0)
+
+ params = "state=p,,p"
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(StringSliceStruct)
+ Form.Bind(req, q)
+ assert.EqualValues(t, []string{"p", "p"}, q.State)
+}
+
+func TestBindingJSON(t *testing.T) {
+ testBodyBinding(t,
+ JSON, "json",
+ "/", "/",
+ `{"foo": "bar"}`, `{"bar": "foo"}`)
+}
+
+func TestBindingForm(t *testing.T) {
+ testFormBinding(t, "POST",
+ "/", "/",
+ "foo=bar&bar=foo&slice=a&slice=b", "bar2=foo")
+}
+
+func TestBindingForm2(t *testing.T) {
+ testFormBinding(t, "GET",
+ "/?foo=bar&bar=foo", "/?bar2=foo",
+ "", "")
+}
+
+func TestBindingQuery(t *testing.T) {
+ testQueryBinding(t, "POST",
+ "/?foo=bar&bar=foo", "/",
+ "foo=unused", "bar2=foo")
+}
+
+func TestBindingQuery2(t *testing.T) {
+ testQueryBinding(t, "GET",
+ "/?foo=bar&bar=foo", "/?bar2=foo",
+ "foo=unused", "")
+}
+
+func TestBindingXML(t *testing.T) {
+ testBodyBinding(t,
+ XML, "xml",
+ "/", "/",
+ "", "")
+}
+
+func createFormPostRequest() *http.Request {
+ req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
+ req.Header.Set("Content-Type", MIMEPOSTForm)
+ return req
+}
+
+func createFormMultipartRequest() *http.Request {
+ boundary := "--testboundary"
+ body := new(bytes.Buffer)
+ mw := multipart.NewWriter(body)
+ defer mw.Close()
+
+ mw.SetBoundary(boundary)
+ mw.WriteField("foo", "bar")
+ mw.WriteField("bar", "foo")
+ req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
+ req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
+ return req
+}
+
+func TestBindingFormPost(t *testing.T) {
+ req := createFormPostRequest()
+ var obj FooBarStruct
+ FormPost.Bind(req, &obj)
+
+ assert.Equal(t, obj.Foo, "bar")
+ assert.Equal(t, obj.Bar, "foo")
+}
+
+func TestBindingFormMultipart(t *testing.T) {
+ req := createFormMultipartRequest()
+ var obj FooBarStruct
+ FormMultipart.Bind(req, &obj)
+
+ assert.Equal(t, obj.Foo, "bar")
+ assert.Equal(t, obj.Bar, "foo")
+}
+
+func TestValidationFails(t *testing.T) {
+ var obj FooStruct
+ req := requestWithBody("POST", "/", `{"bar": "foo"}`)
+ err := JSON.Bind(req, &obj)
+ assert.Error(t, err)
+}
+
+func TestValidationDisabled(t *testing.T) {
+ backup := Validator
+ Validator = nil
+ defer func() { Validator = backup }()
+
+ var obj FooStruct
+ req := requestWithBody("POST", "/", `{"bar": "foo"}`)
+ err := JSON.Bind(req, &obj)
+ assert.NoError(t, err)
+}
+
+func TestExistsSucceeds(t *testing.T) {
+ type HogeStruct struct {
+ Hoge *int `json:"hoge" binding:"exists"`
+ }
+
+ var obj HogeStruct
+ req := requestWithBody("POST", "/", `{"hoge": 0}`)
+ err := JSON.Bind(req, &obj)
+ assert.NoError(t, err)
+}
+
+func TestFormDefaultValue(t *testing.T) {
+ params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8"
+ req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q := new(ComplexDefaultStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Equal(t, 333, q.Int)
+ assert.Equal(t, "hello", q.String)
+ assert.Equal(t, true, q.Bool)
+ assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice)
+ assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice)
+
+ params = "string=hello&bool=false"
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(ComplexDefaultStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Equal(t, 999, q.Int)
+ assert.Equal(t, "hello", q.String)
+ assert.Equal(t, false, q.Bool)
+ assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
+ assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
+
+ params = "strings=hello"
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(ComplexDefaultStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Equal(t, 999, q.Int)
+ assert.Equal(t, "default-string", q.String)
+ assert.Equal(t, false, q.Bool)
+ assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
+ assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
+
+ params = "int=&string=&bool=true&int64_slice=&int8_slice="
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ q = new(ComplexDefaultStruct)
+ assert.NoError(t, Form.Bind(req, q))
+ assert.Equal(t, 999, q.Int)
+ assert.Equal(t, "default-string", q.String)
+ assert.Equal(t, true, q.Bool)
+ assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
+ assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
+}
+
+func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) {
+ b := Form
+ assert.Equal(t, b.Name(), "form")
+
+ obj := FooBarStruct{}
+ req := requestWithBody(method, path, body)
+ if method == "POST" {
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ }
+ err := b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.Equal(t, obj.Foo, "bar")
+ assert.Equal(t, obj.Bar, "foo")
+
+ obj = FooBarStruct{}
+ req = requestWithBody(method, badPath, badBody)
+ err = JSON.Bind(req, &obj)
+ assert.Error(t, err)
+}
+
+func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) {
+ b := Query
+ assert.Equal(t, b.Name(), "query")
+
+ obj := FooBarStruct{}
+ req := requestWithBody(method, path, body)
+ if method == "POST" {
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ }
+ err := b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.Equal(t, obj.Foo, "bar")
+ assert.Equal(t, obj.Bar, "foo")
+}
+
+func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
+ assert.Equal(t, b.Name(), name)
+
+ obj := FooStruct{}
+ req := requestWithBody("POST", path, body)
+ err := b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.Equal(t, obj.Foo, "bar")
+
+ obj = FooStruct{}
+ req = requestWithBody("POST", badPath, badBody)
+ err = JSON.Bind(req, &obj)
+ assert.Error(t, err)
+}
+
+func requestWithBody(method, path, body string) (req *http.Request) {
+ req, _ = http.NewRequest(method, path, bytes.NewBufferString(body))
+ return
+}
+func BenchmarkBindingForm(b *testing.B) {
+ req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w")
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ f := Form
+ for i := 0; i < b.N; i++ {
+ obj := FooBarStruct{}
+ f.Bind(req, &obj)
+ }
+}
diff --git a/pkg/net/http/blademaster/binding/default_validator.go b/pkg/net/http/blademaster/binding/default_validator.go
new file mode 100644
index 000000000..5dbf67ed2
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/default_validator.go
@@ -0,0 +1,45 @@
+package binding
+
+import (
+ "reflect"
+ "sync"
+
+ "gopkg.in/go-playground/validator.v9"
+)
+
+type defaultValidator struct {
+ once sync.Once
+ validate *validator.Validate
+}
+
+var _ StructValidator = &defaultValidator{}
+
+func (v *defaultValidator) ValidateStruct(obj interface{}) error {
+ if kindOfData(obj) == reflect.Struct {
+ v.lazyinit()
+ if err := v.validate.Struct(obj); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
+ v.lazyinit()
+ return v.validate.RegisterValidation(key, fn)
+}
+
+func (v *defaultValidator) lazyinit() {
+ v.once.Do(func() {
+ v.validate = validator.New()
+ })
+}
+
+func kindOfData(data interface{}) reflect.Kind {
+ value := reflect.ValueOf(data)
+ valueType := value.Kind()
+ if valueType == reflect.Ptr {
+ valueType = value.Elem().Kind()
+ }
+ return valueType
+}
diff --git a/pkg/net/http/blademaster/binding/example/test.pb.go b/pkg/net/http/blademaster/binding/example/test.pb.go
new file mode 100644
index 000000000..3de8444ff
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/example/test.pb.go
@@ -0,0 +1,113 @@
+// Code generated by protoc-gen-go.
+// source: test.proto
+// DO NOT EDIT!
+
+/*
+Package example is a generated protocol buffer package.
+
+It is generated from these files:
+ test.proto
+
+It has these top-level messages:
+ Test
+*/
+package example
+
+import proto "github.com/golang/protobuf/proto"
+import math "math"
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = math.Inf
+
+type FOO int32
+
+const (
+ FOO_X FOO = 17
+)
+
+var FOO_name = map[int32]string{
+ 17: "X",
+}
+var FOO_value = map[string]int32{
+ "X": 17,
+}
+
+func (x FOO) Enum() *FOO {
+ p := new(FOO)
+ *p = x
+ return p
+}
+func (x FOO) String() string {
+ return proto.EnumName(FOO_name, int32(x))
+}
+func (x *FOO) UnmarshalJSON(data []byte) error {
+ value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO")
+ if err != nil {
+ return err
+ }
+ *x = FOO(value)
+ return nil
+}
+
+type Test struct {
+ Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
+ Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
+ Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
+ Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Test) Reset() { *m = Test{} }
+func (m *Test) String() string { return proto.CompactTextString(m) }
+func (*Test) ProtoMessage() {}
+
+const Default_Test_Type int32 = 77
+
+func (m *Test) GetLabel() string {
+ if m != nil && m.Label != nil {
+ return *m.Label
+ }
+ return ""
+}
+
+func (m *Test) GetType() int32 {
+ if m != nil && m.Type != nil {
+ return *m.Type
+ }
+ return Default_Test_Type
+}
+
+func (m *Test) GetReps() []int64 {
+ if m != nil {
+ return m.Reps
+ }
+ return nil
+}
+
+func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
+ if m != nil {
+ return m.Optionalgroup
+ }
+ return nil
+}
+
+type Test_OptionalGroup struct {
+ RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
+func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
+func (*Test_OptionalGroup) ProtoMessage() {}
+
+func (m *Test_OptionalGroup) GetRequiredField() string {
+ if m != nil && m.RequiredField != nil {
+ return *m.RequiredField
+ }
+ return ""
+}
+
+func init() {
+ proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
+}
diff --git a/pkg/net/http/blademaster/binding/example/test.proto b/pkg/net/http/blademaster/binding/example/test.proto
new file mode 100644
index 000000000..8ee9800aa
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/example/test.proto
@@ -0,0 +1,12 @@
+package example;
+
+enum FOO {X=17;};
+
+message Test {
+ required string label = 1;
+ optional int32 type = 2[default=77];
+ repeated int64 reps = 3;
+ optional group OptionalGroup = 4{
+ required string RequiredField = 5;
+ }
+}
diff --git a/pkg/net/http/blademaster/binding/example_test.go b/pkg/net/http/blademaster/binding/example_test.go
new file mode 100644
index 000000000..c667a517c
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/example_test.go
@@ -0,0 +1,36 @@
+package binding
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+)
+
+type Arg struct {
+ Max int64 `form:"max" validate:"max=10"`
+ Min int64 `form:"min" validate:"min=2"`
+ Range int64 `form:"range" validate:"min=1,max=10"`
+ // use split option to split arg 1,2,3 into slice [1 2 3]
+ // otherwise slice type with parse url.Values (eg:a=b&a=c) default.
+ Slice []int64 `form:"slice,split" validate:"min=1"`
+}
+
+func ExampleBinding() {
+ req := initHTTP("max=9&min=3&range=3&slice=1,2,3")
+ arg := new(Arg)
+ if err := Form.Bind(req, arg); err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice)
+ // Output:
+ // arg.Max 9
+ // arg.Min 3
+ // arg.Range 3
+ // arg.Slice [1 2 3]
+}
+
+func initHTTP(params string) (req *http.Request) {
+ req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
+ req.ParseForm()
+ return
+}
diff --git a/pkg/net/http/blademaster/binding/form.go b/pkg/net/http/blademaster/binding/form.go
new file mode 100644
index 000000000..61aa5ee83
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/form.go
@@ -0,0 +1,55 @@
+package binding
+
+import (
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+const defaultMemory = 32 * 1024 * 1024
+
+type formBinding struct{}
+type formPostBinding struct{}
+type formMultipartBinding struct{}
+
+func (f formBinding) Name() string {
+ return "form"
+}
+
+func (f formBinding) Bind(req *http.Request, obj interface{}) error {
+ if err := req.ParseForm(); err != nil {
+ return errors.WithStack(err)
+ }
+ if err := mapForm(obj, req.Form); err != nil {
+ return err
+ }
+ return validate(obj)
+}
+
+func (f formPostBinding) Name() string {
+ return "form-urlencoded"
+}
+
+func (f formPostBinding) Bind(req *http.Request, obj interface{}) error {
+ if err := req.ParseForm(); err != nil {
+ return errors.WithStack(err)
+ }
+ if err := mapForm(obj, req.PostForm); err != nil {
+ return err
+ }
+ return validate(obj)
+}
+
+func (f formMultipartBinding) Name() string {
+ return "multipart/form-data"
+}
+
+func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
+ if err := req.ParseMultipartForm(defaultMemory); err != nil {
+ return errors.WithStack(err)
+ }
+ if err := mapForm(obj, req.MultipartForm.Value); err != nil {
+ return err
+ }
+ return validate(obj)
+}
diff --git a/pkg/net/http/blademaster/binding/form_mapping.go b/pkg/net/http/blademaster/binding/form_mapping.go
new file mode 100644
index 000000000..ac4ecd116
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/form_mapping.go
@@ -0,0 +1,276 @@
+package binding
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/pkg/errors"
+)
+
+// scache struct reflect type cache.
+var scache = &cache{
+ data: make(map[reflect.Type]*sinfo),
+}
+
+type cache struct {
+ data map[reflect.Type]*sinfo
+ mutex sync.RWMutex
+}
+
+func (c *cache) get(obj reflect.Type) (s *sinfo) {
+ var ok bool
+ c.mutex.RLock()
+ if s, ok = c.data[obj]; !ok {
+ c.mutex.RUnlock()
+ s = c.set(obj)
+ return
+ }
+ c.mutex.RUnlock()
+ return
+}
+
+func (c *cache) set(obj reflect.Type) (s *sinfo) {
+ s = new(sinfo)
+ tp := obj.Elem()
+ for i := 0; i < tp.NumField(); i++ {
+ fd := new(field)
+ fd.tp = tp.Field(i)
+ tag := fd.tp.Tag.Get("form")
+ fd.name, fd.option = parseTag(tag)
+ if defV := fd.tp.Tag.Get("default"); defV != "" {
+ dv := reflect.New(fd.tp.Type).Elem()
+ setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option)
+ fd.hasDefault = true
+ fd.defaultValue = dv
+ }
+ s.field = append(s.field, fd)
+ }
+ c.mutex.Lock()
+ c.data[obj] = s
+ c.mutex.Unlock()
+ return
+}
+
+type sinfo struct {
+ field []*field
+}
+
+type field struct {
+ tp reflect.StructField
+ name string
+ option tagOptions
+
+ hasDefault bool // if field had default value
+ defaultValue reflect.Value // field default value
+}
+
+func mapForm(ptr interface{}, form map[string][]string) error {
+ sinfo := scache.get(reflect.TypeOf(ptr))
+ val := reflect.ValueOf(ptr).Elem()
+ for i, fd := range sinfo.field {
+ typeField := fd.tp
+ structField := val.Field(i)
+ if !structField.CanSet() {
+ continue
+ }
+
+ structFieldKind := structField.Kind()
+ inputFieldName := fd.name
+ if inputFieldName == "" {
+ inputFieldName = typeField.Name
+
+ // if "form" tag is nil, we inspect if the field is a struct.
+ // this would not make sense for JSON parsing but it does for a form
+ // since data is flatten
+ if structFieldKind == reflect.Struct {
+ err := mapForm(structField.Addr().Interface(), form)
+ if err != nil {
+ return err
+ }
+ continue
+ }
+ }
+ inputValue, exists := form[inputFieldName]
+ if !exists {
+ // Set the field as default value when the input value is not exist
+ if fd.hasDefault {
+ structField.Set(fd.defaultValue)
+ }
+ continue
+ }
+ // Set the field as default value when the input value is empty
+ if fd.hasDefault && inputValue[0] == "" {
+ structField.Set(fd.defaultValue)
+ continue
+ }
+ if _, isTime := structField.Interface().(time.Time); isTime {
+ if err := setTimeField(inputValue[0], typeField, structField); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error {
+ switch valueKind {
+ case reflect.Int:
+ return setIntField(val[0], 0, structField)
+ case reflect.Int8:
+ return setIntField(val[0], 8, structField)
+ case reflect.Int16:
+ return setIntField(val[0], 16, structField)
+ case reflect.Int32:
+ return setIntField(val[0], 32, structField)
+ case reflect.Int64:
+ return setIntField(val[0], 64, structField)
+ case reflect.Uint:
+ return setUintField(val[0], 0, structField)
+ case reflect.Uint8:
+ return setUintField(val[0], 8, structField)
+ case reflect.Uint16:
+ return setUintField(val[0], 16, structField)
+ case reflect.Uint32:
+ return setUintField(val[0], 32, structField)
+ case reflect.Uint64:
+ return setUintField(val[0], 64, structField)
+ case reflect.Bool:
+ return setBoolField(val[0], structField)
+ case reflect.Float32:
+ return setFloatField(val[0], 32, structField)
+ case reflect.Float64:
+ return setFloatField(val[0], 64, structField)
+ case reflect.String:
+ structField.SetString(val[0])
+ case reflect.Slice:
+ if option.Contains("split") {
+ val = strings.Split(val[0], ",")
+ }
+ filtered := filterEmpty(val)
+ switch structField.Type().Elem().Kind() {
+ case reflect.Int64:
+ valSli := make([]int64, 0, len(filtered))
+ for i := 0; i < len(filtered); i++ {
+ d, err := strconv.ParseInt(filtered[i], 10, 64)
+ if err != nil {
+ return err
+ }
+ valSli = append(valSli, d)
+ }
+ structField.Set(reflect.ValueOf(valSli))
+ case reflect.String:
+ valSli := make([]string, 0, len(filtered))
+ for i := 0; i < len(filtered); i++ {
+ valSli = append(valSli, filtered[i])
+ }
+ structField.Set(reflect.ValueOf(valSli))
+ default:
+ sliceOf := structField.Type().Elem().Kind()
+ numElems := len(filtered)
+ slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered))
+ for i := 0; i < numElems; i++ {
+ if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil {
+ return err
+ }
+ }
+ structField.Set(slice)
+ }
+ default:
+ return errors.New("Unknown type")
+ }
+ return nil
+}
+
+func setIntField(val string, bitSize int, field reflect.Value) error {
+ if val == "" {
+ val = "0"
+ }
+ intVal, err := strconv.ParseInt(val, 10, bitSize)
+ if err == nil {
+ field.SetInt(intVal)
+ }
+ return errors.WithStack(err)
+}
+
+func setUintField(val string, bitSize int, field reflect.Value) error {
+ if val == "" {
+ val = "0"
+ }
+ uintVal, err := strconv.ParseUint(val, 10, bitSize)
+ if err == nil {
+ field.SetUint(uintVal)
+ }
+ return errors.WithStack(err)
+}
+
+func setBoolField(val string, field reflect.Value) error {
+ if val == "" {
+ val = "false"
+ }
+ boolVal, err := strconv.ParseBool(val)
+ if err == nil {
+ field.SetBool(boolVal)
+ }
+ return nil
+}
+
+func setFloatField(val string, bitSize int, field reflect.Value) error {
+ if val == "" {
+ val = "0.0"
+ }
+ floatVal, err := strconv.ParseFloat(val, bitSize)
+ if err == nil {
+ field.SetFloat(floatVal)
+ }
+ return errors.WithStack(err)
+}
+
+func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
+ timeFormat := structField.Tag.Get("time_format")
+ if timeFormat == "" {
+ return errors.New("Blank time format")
+ }
+
+ if val == "" {
+ value.Set(reflect.ValueOf(time.Time{}))
+ return nil
+ }
+
+ l := time.Local
+ if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
+ l = time.UTC
+ }
+
+ if locTag := structField.Tag.Get("time_location"); locTag != "" {
+ loc, err := time.LoadLocation(locTag)
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ l = loc
+ }
+
+ t, err := time.ParseInLocation(timeFormat, val, l)
+ if err != nil {
+ return errors.WithStack(err)
+ }
+
+ value.Set(reflect.ValueOf(t))
+ return nil
+}
+
+func filterEmpty(val []string) []string {
+ filtered := make([]string, 0, len(val))
+ for _, v := range val {
+ if v != "" {
+ filtered = append(filtered, v)
+ }
+ }
+ return filtered
+}
diff --git a/pkg/net/http/blademaster/binding/json.go b/pkg/net/http/blademaster/binding/json.go
new file mode 100644
index 000000000..f01e479b3
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/json.go
@@ -0,0 +1,22 @@
+package binding
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+type jsonBinding struct{}
+
+func (jsonBinding) Name() string {
+ return "json"
+}
+
+func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
+ decoder := json.NewDecoder(req.Body)
+ if err := decoder.Decode(obj); err != nil {
+ return errors.WithStack(err)
+ }
+ return validate(obj)
+}
diff --git a/pkg/net/http/blademaster/binding/query.go b/pkg/net/http/blademaster/binding/query.go
new file mode 100644
index 000000000..b169436eb
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/query.go
@@ -0,0 +1,19 @@
+package binding
+
+import (
+ "net/http"
+)
+
+type queryBinding struct{}
+
+func (queryBinding) Name() string {
+ return "query"
+}
+
+func (queryBinding) Bind(req *http.Request, obj interface{}) error {
+ values := req.URL.Query()
+ if err := mapForm(obj, values); err != nil {
+ return err
+ }
+ return validate(obj)
+}
diff --git a/pkg/net/http/blademaster/binding/tags.go b/pkg/net/http/blademaster/binding/tags.go
new file mode 100644
index 000000000..535bd8624
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/tags.go
@@ -0,0 +1,44 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+ "strings"
+)
+
+// tagOptions is the string following a comma in a struct field's "json"
+// tag, or the empty string. It does not include the leading comma.
+type tagOptions string
+
+// parseTag splits a struct field's json tag into its name and
+// comma-separated options.
+func parseTag(tag string) (string, tagOptions) {
+ if idx := strings.Index(tag, ","); idx != -1 {
+ return tag[:idx], tagOptions(tag[idx+1:])
+ }
+ return tag, tagOptions("")
+}
+
+// Contains reports whether a comma-separated list of options
+// contains a particular substr flag. substr must be surrounded by a
+// string boundary or commas.
+func (o tagOptions) Contains(optionName string) bool {
+ if len(o) == 0 {
+ return false
+ }
+ s := string(o)
+ for s != "" {
+ var next string
+ i := strings.Index(s, ",")
+ if i >= 0 {
+ s, next = s[:i], s[i+1:]
+ }
+ if s == optionName {
+ return true
+ }
+ s = next
+ }
+ return false
+}
diff --git a/pkg/net/http/blademaster/binding/validate_test.go b/pkg/net/http/blademaster/binding/validate_test.go
new file mode 100644
index 000000000..ac1793713
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/validate_test.go
@@ -0,0 +1,209 @@
+package binding
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type testInterface interface {
+ String() string
+}
+
+type substructNoValidation struct {
+ IString string
+ IInt int
+}
+
+type mapNoValidationSub map[string]substructNoValidation
+
+type structNoValidationValues struct {
+ substructNoValidation
+
+ Boolean bool
+
+ Uinteger uint
+ Integer int
+ Integer8 int8
+ Integer16 int16
+ Integer32 int32
+ Integer64 int64
+ Uinteger8 uint8
+ Uinteger16 uint16
+ Uinteger32 uint32
+ Uinteger64 uint64
+
+ Float32 float32
+ Float64 float64
+
+ String string
+
+ Date time.Time
+
+ Struct substructNoValidation
+ InlinedStruct struct {
+ String []string
+ Integer int
+ }
+
+ IntSlice []int
+ IntPointerSlice []*int
+ StructPointerSlice []*substructNoValidation
+ StructSlice []substructNoValidation
+ InterfaceSlice []testInterface
+
+ UniversalInterface interface{}
+ CustomInterface testInterface
+
+ FloatMap map[string]float32
+ StructMap mapNoValidationSub
+}
+
+func createNoValidationValues() structNoValidationValues {
+ integer := 1
+ s := structNoValidationValues{
+ Boolean: true,
+ Uinteger: 1 << 29,
+ Integer: -10000,
+ Integer8: 120,
+ Integer16: -20000,
+ Integer32: 1 << 29,
+ Integer64: 1 << 61,
+ Uinteger8: 250,
+ Uinteger16: 50000,
+ Uinteger32: 1 << 31,
+ Uinteger64: 1 << 62,
+ Float32: 123.456,
+ Float64: 123.456789,
+ String: "text",
+ Date: time.Time{},
+ CustomInterface: &bytes.Buffer{},
+ Struct: substructNoValidation{},
+ IntSlice: []int{-3, -2, 1, 0, 1, 2, 3},
+ IntPointerSlice: []*int{&integer},
+ StructSlice: []substructNoValidation{},
+ UniversalInterface: 1.2,
+ FloatMap: map[string]float32{
+ "foo": 1.23,
+ "bar": 232.323,
+ },
+ StructMap: mapNoValidationSub{
+ "foo": substructNoValidation{},
+ "bar": substructNoValidation{},
+ },
+ // StructPointerSlice []noValidationSub
+ // InterfaceSlice []testInterface
+ }
+ s.InlinedStruct.Integer = 1000
+ s.InlinedStruct.String = []string{"first", "second"}
+ s.IString = "substring"
+ s.IInt = 987654
+ return s
+}
+
+func TestValidateNoValidationValues(t *testing.T) {
+ origin := createNoValidationValues()
+ test := createNoValidationValues()
+ empty := structNoValidationValues{}
+
+ assert.Nil(t, validate(test))
+ assert.Nil(t, validate(&test))
+ assert.Nil(t, validate(empty))
+ assert.Nil(t, validate(&empty))
+
+ assert.Equal(t, origin, test)
+}
+
+type structNoValidationPointer struct {
+ // substructNoValidation
+
+ Boolean bool
+
+ Uinteger *uint
+ Integer *int
+ Integer8 *int8
+ Integer16 *int16
+ Integer32 *int32
+ Integer64 *int64
+ Uinteger8 *uint8
+ Uinteger16 *uint16
+ Uinteger32 *uint32
+ Uinteger64 *uint64
+
+ Float32 *float32
+ Float64 *float64
+
+ String *string
+
+ Date *time.Time
+
+ Struct *substructNoValidation
+
+ IntSlice *[]int
+ IntPointerSlice *[]*int
+ StructPointerSlice *[]*substructNoValidation
+ StructSlice *[]substructNoValidation
+ InterfaceSlice *[]testInterface
+
+ FloatMap *map[string]float32
+ StructMap *mapNoValidationSub
+}
+
+func TestValidateNoValidationPointers(t *testing.T) {
+ //origin := createNoValidation_values()
+ //test := createNoValidation_values()
+ empty := structNoValidationPointer{}
+
+ //assert.Nil(t, validate(test))
+ //assert.Nil(t, validate(&test))
+ assert.Nil(t, validate(empty))
+ assert.Nil(t, validate(&empty))
+
+ //assert.Equal(t, origin, test)
+}
+
+type Object map[string]interface{}
+
+func TestValidatePrimitives(t *testing.T) {
+ obj := Object{"foo": "bar", "bar": 1}
+ assert.NoError(t, validate(obj))
+ assert.NoError(t, validate(&obj))
+ assert.Equal(t, obj, Object{"foo": "bar", "bar": 1})
+
+ obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}}
+ assert.NoError(t, validate(obj2))
+ assert.NoError(t, validate(&obj2))
+
+ nu := 10
+ assert.NoError(t, validate(nu))
+ assert.NoError(t, validate(&nu))
+ assert.Equal(t, nu, 10)
+
+ str := "value"
+ assert.NoError(t, validate(str))
+ assert.NoError(t, validate(&str))
+ assert.Equal(t, str, "value")
+}
+
+// structCustomValidation is a helper struct we use to check that
+// custom validation can be registered on it.
+// The `notone` binding directive is for custom validation and registered later.
+// type structCustomValidation struct {
+// Integer int `binding:"notone"`
+// }
+
+// notOne is a custom validator meant to be used with `validator.v8` library.
+// The method signature for `v9` is significantly different and this function
+// would need to be changed for tests to pass after upgrade.
+// See https://github.com/gin-gonic/gin/pull/1015.
+// func notOne(
+// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
+// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
+// ) bool {
+// if val, ok := field.Interface().(int); ok {
+// return val != 1
+// }
+// return false
+// }
diff --git a/pkg/net/http/blademaster/binding/xml.go b/pkg/net/http/blademaster/binding/xml.go
new file mode 100644
index 000000000..99b303c2f
--- /dev/null
+++ b/pkg/net/http/blademaster/binding/xml.go
@@ -0,0 +1,22 @@
+package binding
+
+import (
+ "encoding/xml"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+type xmlBinding struct{}
+
+func (xmlBinding) Name() string {
+ return "xml"
+}
+
+func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
+ decoder := xml.NewDecoder(req.Body)
+ if err := decoder.Decode(obj); err != nil {
+ return errors.WithStack(err)
+ }
+ return validate(obj)
+}
diff --git a/pkg/net/http/blademaster/context.go b/pkg/net/http/blademaster/context.go
new file mode 100644
index 000000000..59154435c
--- /dev/null
+++ b/pkg/net/http/blademaster/context.go
@@ -0,0 +1,306 @@
+package blademaster
+
+import (
+ "context"
+ "math"
+ "net/http"
+ "strconv"
+
+ "github.com/bilibili/Kratos/pkg/ecode"
+ "github.com/bilibili/Kratos/pkg/net/http/blademaster/binding"
+ "github.com/bilibili/Kratos/pkg/net/http/blademaster/render"
+
+ "github.com/gogo/protobuf/proto"
+ "github.com/gogo/protobuf/types"
+ "github.com/pkg/errors"
+)
+
+const (
+ _abortIndex int8 = math.MaxInt8 / 2
+)
+
+var (
+ _openParen = []byte("(")
+ _closeParen = []byte(")")
+)
+
+// Context is the most important part. It allows us to pass variables between
+// middleware, manage the flow, validate the JSON of a request and render a
+// JSON response for example.
+type Context struct {
+ context.Context
+
+ Request *http.Request
+ Writer http.ResponseWriter
+
+ // flow control
+ index int8
+ handlers []HandlerFunc
+
+ // Keys is a key/value pair exclusively for the context of each request.
+ Keys map[string]interface{}
+
+ Error error
+
+ method string
+ engine *Engine
+}
+
+/************************************/
+/*********** FLOW CONTROL ***********/
+/************************************/
+
+// Next should be used only inside middleware.
+// It executes the pending handlers in the chain inside the calling handler.
+// See example in godoc.
+func (c *Context) Next() {
+ c.index++
+ s := int8(len(c.handlers))
+ for ; c.index < s; c.index++ {
+ // only check method on last handler, otherwise middlewares
+ // will never be effected if request method is not matched
+ if c.index == s-1 && c.method != c.Request.Method {
+ code := http.StatusMethodNotAllowed
+ c.Error = ecode.MethodNotAllowed
+ http.Error(c.Writer, http.StatusText(code), code)
+ return
+ }
+
+ c.handlers[c.index](c)
+ }
+}
+
+// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
+// Let's say you have an authorization middleware that validates that the current request is authorized.
+// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
+// for this request are not called.
+func (c *Context) Abort() {
+ c.index = _abortIndex
+}
+
+// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
+// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
+func (c *Context) AbortWithStatus(code int) {
+ c.Status(code)
+ c.Abort()
+}
+
+// IsAborted returns true if the current context was aborted.
+func (c *Context) IsAborted() bool {
+ return c.index >= _abortIndex
+}
+
+/************************************/
+/******** METADATA MANAGEMENT********/
+/************************************/
+
+// Set is used to store a new key/value pair exclusively for this context.
+// It also lazy initializes c.Keys if it was not used previously.
+func (c *Context) Set(key string, value interface{}) {
+ if c.Keys == nil {
+ c.Keys = make(map[string]interface{})
+ }
+ c.Keys[key] = value
+}
+
+// Get returns the value for the given key, ie: (value, true).
+// If the value does not exists it returns (nil, false)
+func (c *Context) Get(key string) (value interface{}, exists bool) {
+ value, exists = c.Keys[key]
+ return
+}
+
+/************************************/
+/******** RESPONSE RENDERING ********/
+/************************************/
+
+// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
+func bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+// Status sets the HTTP response code.
+func (c *Context) Status(code int) {
+ c.Writer.WriteHeader(code)
+}
+
+// Render http response with http code by a render instance.
+func (c *Context) Render(code int, r render.Render) {
+ r.WriteContentType(c.Writer)
+ if code > 0 {
+ c.Status(code)
+ }
+
+ if !bodyAllowedForStatus(code) {
+ return
+ }
+
+ params := c.Request.Form
+
+ cb := params.Get("callback")
+ jsonp := cb != "" && params.Get("jsonp") == "jsonp"
+ if jsonp {
+ c.Writer.Write([]byte(cb))
+ c.Writer.Write(_openParen)
+ }
+
+ if err := r.Render(c.Writer); err != nil {
+ c.Error = err
+ return
+ }
+
+ if jsonp {
+ if _, err := c.Writer.Write(_closeParen); err != nil {
+ c.Error = errors.WithStack(err)
+ }
+ }
+}
+
+// JSON serializes the given struct as JSON into the response body.
+// It also sets the Content-Type as "application/json".
+func (c *Context) JSON(data interface{}, err error) {
+ code := http.StatusOK
+ c.Error = err
+ bcode := ecode.Cause(err)
+ // TODO app allow 5xx?
+ /*
+ if bcode.Code() == -500 {
+ code = http.StatusServiceUnavailable
+ }
+ */
+ writeStatusCode(c.Writer, bcode.Code())
+ c.Render(code, render.JSON{
+ Code: bcode.Code(),
+ Message: bcode.Message(),
+ Data: data,
+ })
+}
+
+// JSONMap serializes the given map as map JSON into the response body.
+// It also sets the Content-Type as "application/json".
+func (c *Context) JSONMap(data map[string]interface{}, err error) {
+ code := http.StatusOK
+ c.Error = err
+ bcode := ecode.Cause(err)
+ // TODO app allow 5xx?
+ /*
+ if bcode.Code() == -500 {
+ code = http.StatusServiceUnavailable
+ }
+ */
+ writeStatusCode(c.Writer, bcode.Code())
+ data["code"] = bcode.Code()
+ if _, ok := data["message"]; !ok {
+ data["message"] = bcode.Message()
+ }
+ c.Render(code, render.MapJSON(data))
+}
+
+// XML serializes the given struct as XML into the response body.
+// It also sets the Content-Type as "application/xml".
+func (c *Context) XML(data interface{}, err error) {
+ code := http.StatusOK
+ c.Error = err
+ bcode := ecode.Cause(err)
+ // TODO app allow 5xx?
+ /*
+ if bcode.Code() == -500 {
+ code = http.StatusServiceUnavailable
+ }
+ */
+ writeStatusCode(c.Writer, bcode.Code())
+ c.Render(code, render.XML{
+ Code: bcode.Code(),
+ Message: bcode.Message(),
+ Data: data,
+ })
+}
+
+// Protobuf serializes the given struct as PB into the response body.
+// It also sets the ContentType as "application/x-protobuf".
+func (c *Context) Protobuf(data proto.Message, err error) {
+ var (
+ bytes []byte
+ )
+
+ code := http.StatusOK
+ c.Error = err
+ bcode := ecode.Cause(err)
+
+ any := new(types.Any)
+ if data != nil {
+ if bytes, err = proto.Marshal(data); err != nil {
+ c.Error = errors.WithStack(err)
+ return
+ }
+ any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data)
+ any.Value = bytes
+ }
+ writeStatusCode(c.Writer, bcode.Code())
+ c.Render(code, render.PB{
+ Code: int64(bcode.Code()),
+ Message: bcode.Message(),
+ Data: any,
+ })
+}
+
+// Bytes writes some data into the body stream and updates the HTTP code.
+func (c *Context) Bytes(code int, contentType string, data ...[]byte) {
+ c.Render(code, render.Data{
+ ContentType: contentType,
+ Data: data,
+ })
+}
+
+// String writes the given string into the response body.
+func (c *Context) String(code int, format string, values ...interface{}) {
+ c.Render(code, render.String{Format: format, Data: values})
+}
+
+// Redirect returns a HTTP redirect to the specific location.
+func (c *Context) Redirect(code int, location string) {
+ c.Render(-1, render.Redirect{
+ Code: code,
+ Location: location,
+ Request: c.Request,
+ })
+}
+
+// BindWith bind req arg with parser.
+func (c *Context) BindWith(obj interface{}, b binding.Binding) error {
+ return c.mustBindWith(obj, b)
+}
+
+// Bind bind req arg with defult form binding.
+func (c *Context) Bind(obj interface{}) error {
+ return c.mustBindWith(obj, binding.Form)
+}
+
+// mustBindWith binds the passed struct pointer using the specified binding engine.
+// It will abort the request with HTTP 400 if any error ocurrs.
+// See the binding package.
+func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) {
+ if err = b.Bind(c.Request, obj); err != nil {
+ c.Error = ecode.RequestErr
+ c.Render(http.StatusOK, render.JSON{
+ Code: ecode.RequestErr.Code(),
+ Message: err.Error(),
+ Data: nil,
+ })
+ c.Abort()
+ }
+ return
+}
+
+func writeStatusCode(w http.ResponseWriter, ecode int) {
+ header := w.Header()
+ header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
+}
diff --git a/pkg/net/http/blademaster/cors.go b/pkg/net/http/blademaster/cors.go
new file mode 100644
index 000000000..f4e88a1db
--- /dev/null
+++ b/pkg/net/http/blademaster/cors.go
@@ -0,0 +1,249 @@
+package blademaster
+
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/bilibili/Kratos/pkg/log"
+
+ "github.com/pkg/errors"
+)
+
+// CORSConfig represents all available options for the middleware.
+type CORSConfig struct {
+ AllowAllOrigins bool
+
+ // AllowedOrigins is a list of origins a cross-domain request can be executed from.
+ // If the special "*" value is present in the list, all origins will be allowed.
+ // Default value is []
+ AllowOrigins []string
+
+ // AllowOriginFunc is a custom function to validate the origin. It take the origin
+ // as argument and returns true if allowed or false otherwise. If this option is
+ // set, the content of AllowedOrigins is ignored.
+ AllowOriginFunc func(origin string) bool
+
+ // AllowedMethods is a list of methods the client is allowed to use with
+ // cross-domain requests. Default value is simple methods (GET and POST)
+ AllowMethods []string
+
+ // AllowedHeaders is list of non simple headers the client is allowed to use with
+ // cross-domain requests.
+ AllowHeaders []string
+
+ // AllowCredentials indicates whether the request can include user credentials like
+ // cookies, HTTP authentication or client side SSL certificates.
+ AllowCredentials bool
+
+ // ExposedHeaders indicates which headers are safe to expose to the API of a CORS
+ // API specification
+ ExposeHeaders []string
+
+ // MaxAge indicates how long (in seconds) the results of a preflight request
+ // can be cached
+ MaxAge time.Duration
+}
+
+type cors struct {
+ allowAllOrigins bool
+ allowCredentials bool
+ allowOriginFunc func(string) bool
+ allowOrigins []string
+ normalHeaders http.Header
+ preflightHeaders http.Header
+}
+
+type converter func(string) string
+
+// Validate is check configuration of user defined.
+func (c *CORSConfig) Validate() error {
+ if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
+ return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
+ }
+ if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
+ return errors.New("conflict settings: all origins disabled")
+ }
+ for _, origin := range c.AllowOrigins {
+ if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
+ return errors.New("bad origin: origins must either be '*' or include http:// or https://")
+ }
+ }
+ return nil
+}
+
+// CORS returns the location middleware with default configuration.
+func CORS(allowOriginHosts []string) HandlerFunc {
+ config := &CORSConfig{
+ AllowMethods: []string{"GET", "POST"},
+ AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
+ AllowCredentials: true,
+ MaxAge: time.Duration(0),
+ AllowOriginFunc: func(origin string) bool {
+ for _, host := range allowOriginHosts {
+ if strings.HasSuffix(strings.ToLower(origin), host) {
+ return true
+ }
+ }
+ return false
+ },
+ }
+ return newCORS(config)
+}
+
+// newCORS returns the location middleware with user-defined custom configuration.
+func newCORS(config *CORSConfig) HandlerFunc {
+ if err := config.Validate(); err != nil {
+ panic(err.Error())
+ }
+ cors := &cors{
+ allowOriginFunc: config.AllowOriginFunc,
+ allowAllOrigins: config.AllowAllOrigins,
+ allowCredentials: config.AllowCredentials,
+ allowOrigins: normalize(config.AllowOrigins),
+ normalHeaders: generateNormalHeaders(config),
+ preflightHeaders: generatePreflightHeaders(config),
+ }
+
+ return func(c *Context) {
+ cors.applyCORS(c)
+ }
+}
+
+func (cors *cors) applyCORS(c *Context) {
+ origin := c.Request.Header.Get("Origin")
+ if len(origin) == 0 {
+ // request is not a CORS request
+ return
+ }
+ if !cors.validateOrigin(origin) {
+ log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
+ c.AbortWithStatus(http.StatusForbidden)
+ return
+ }
+
+ if c.Request.Method == "OPTIONS" {
+ cors.handlePreflight(c)
+ defer c.AbortWithStatus(200)
+ } else {
+ cors.handleNormal(c)
+ }
+
+ if !cors.allowAllOrigins {
+ header := c.Writer.Header()
+ header.Set("Access-Control-Allow-Origin", origin)
+ }
+}
+
+func (cors *cors) validateOrigin(origin string) bool {
+ if cors.allowAllOrigins {
+ return true
+ }
+ for _, value := range cors.allowOrigins {
+ if value == origin {
+ return true
+ }
+ }
+ if cors.allowOriginFunc != nil {
+ return cors.allowOriginFunc(origin)
+ }
+ return false
+}
+
+func (cors *cors) handlePreflight(c *Context) {
+ header := c.Writer.Header()
+ for key, value := range cors.preflightHeaders {
+ header[key] = value
+ }
+}
+
+func (cors *cors) handleNormal(c *Context) {
+ header := c.Writer.Header()
+ for key, value := range cors.normalHeaders {
+ header[key] = value
+ }
+}
+
+func generateNormalHeaders(c *CORSConfig) http.Header {
+ headers := make(http.Header)
+ if c.AllowCredentials {
+ headers.Set("Access-Control-Allow-Credentials", "true")
+ }
+
+ // backport support for early browsers
+ if len(c.AllowMethods) > 0 {
+ allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
+ value := strings.Join(allowMethods, ",")
+ headers.Set("Access-Control-Allow-Methods", value)
+ }
+
+ if len(c.ExposeHeaders) > 0 {
+ exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
+ headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
+ }
+ if c.AllowAllOrigins {
+ headers.Set("Access-Control-Allow-Origin", "*")
+ } else {
+ headers.Set("Vary", "Origin")
+ }
+ return headers
+}
+
+func generatePreflightHeaders(c *CORSConfig) http.Header {
+ headers := make(http.Header)
+ if c.AllowCredentials {
+ headers.Set("Access-Control-Allow-Credentials", "true")
+ }
+ if len(c.AllowMethods) > 0 {
+ allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
+ value := strings.Join(allowMethods, ",")
+ headers.Set("Access-Control-Allow-Methods", value)
+ }
+ if len(c.AllowHeaders) > 0 {
+ allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
+ value := strings.Join(allowHeaders, ",")
+ headers.Set("Access-Control-Allow-Headers", value)
+ }
+ if c.MaxAge > time.Duration(0) {
+ value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
+ headers.Set("Access-Control-Max-Age", value)
+ }
+ if c.AllowAllOrigins {
+ headers.Set("Access-Control-Allow-Origin", "*")
+ } else {
+ // Always set Vary headers
+ // see https://github.com/rs/cors/issues/10,
+ // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
+
+ headers.Add("Vary", "Origin")
+ headers.Add("Vary", "Access-Control-Request-Method")
+ headers.Add("Vary", "Access-Control-Request-Headers")
+ }
+ return headers
+}
+
+func normalize(values []string) []string {
+ if values == nil {
+ return nil
+ }
+ distinctMap := make(map[string]bool, len(values))
+ normalized := make([]string, 0, len(values))
+ for _, value := range values {
+ value = strings.TrimSpace(value)
+ value = strings.ToLower(value)
+ if _, seen := distinctMap[value]; !seen {
+ normalized = append(normalized, value)
+ distinctMap[value] = true
+ }
+ }
+ return normalized
+}
+
+func convert(s []string, c converter) []string {
+ var out []string
+ for _, i := range s {
+ out = append(out, c(i))
+ }
+ return out
+}
diff --git a/pkg/net/http/blademaster/csrf.go b/pkg/net/http/blademaster/csrf.go
new file mode 100644
index 000000000..a45825189
--- /dev/null
+++ b/pkg/net/http/blademaster/csrf.go
@@ -0,0 +1,69 @@
+package blademaster
+
+import (
+ "net/url"
+ "regexp"
+ "strings"
+
+ "github.com/bilibili/Kratos/pkg/log"
+)
+
+func matchHostSuffix(suffix string) func(*url.URL) bool {
+ return func(uri *url.URL) bool {
+ return strings.HasSuffix(strings.ToLower(uri.Host), suffix)
+ }
+}
+
+func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
+ return func(uri *url.URL) bool {
+ return pattern.MatchString(strings.ToLower(uri.String()))
+ }
+}
+
+// CSRF returns the csrf middleware to prevent invalid cross site request.
+// Only referer is checked currently.
+func CSRF(allowHosts []string, allowPattern []string) HandlerFunc {
+ validations := []func(*url.URL) bool{}
+
+ addHostSuffix := func(suffix string) {
+ validations = append(validations, matchHostSuffix(suffix))
+ }
+ addPattern := func(pattern string) {
+ validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
+ }
+
+ for _, r := range allowHosts {
+ addHostSuffix(r)
+ }
+ for _, p := range allowPattern {
+ addPattern(p)
+ }
+
+ return func(c *Context) {
+ referer := c.Request.Header.Get("Referer")
+ params := c.Request.Form
+ cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "")
+ if referer == "" {
+ if !cross {
+ return
+ }
+ log.V(5).Info("The request's Referer header is empty.")
+ c.AbortWithStatus(403)
+ return
+ }
+ illegal := true
+ if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
+ for _, validate := range validations {
+ if validate(uri) {
+ illegal = false
+ break
+ }
+ }
+ }
+ if illegal {
+ log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
+ c.AbortWithStatus(403)
+ return
+ }
+ }
+}
diff --git a/pkg/net/http/blademaster/logger.go b/pkg/net/http/blademaster/logger.go
new file mode 100644
index 000000000..58fb74ff1
--- /dev/null
+++ b/pkg/net/http/blademaster/logger.go
@@ -0,0 +1,69 @@
+package blademaster
+
+import (
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/bilibili/Kratos/pkg/ecode"
+ "github.com/bilibili/Kratos/pkg/log"
+ "github.com/bilibili/Kratos/pkg/net/metadata"
+)
+
+// Logger is logger middleware
+func Logger() HandlerFunc {
+ const noUser = "no_user"
+ return func(c *Context) {
+ now := time.Now()
+ ip := metadata.String(c, metadata.RemoteIP)
+ req := c.Request
+ path := req.URL.Path
+ params := req.Form
+ var quota float64
+ if deadline, ok := c.Context.Deadline(); ok {
+ quota = time.Until(deadline).Seconds()
+ }
+
+ c.Next()
+
+ err := c.Error
+ cerr := ecode.Cause(err)
+ dt := time.Since(now)
+ caller := metadata.String(c, metadata.Caller)
+ if caller == "" {
+ caller = noUser
+ }
+
+ stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10))
+ stats.Timing(caller, int64(dt/time.Millisecond), path[1:])
+
+ lf := log.Infov
+ errmsg := ""
+ isSlow := dt >= (time.Millisecond * 500)
+ if err != nil {
+ errmsg = err.Error()
+ lf = log.Errorv
+ if cerr.Code() > 0 {
+ lf = log.Warnv
+ }
+ } else {
+ if isSlow {
+ lf = log.Warnv
+ }
+ }
+ lf(c,
+ log.KVString("method", req.Method),
+ log.KVString("ip", ip),
+ log.KVString("user", caller),
+ log.KVString("path", path),
+ log.KVString("params", params.Encode()),
+ log.KVInt("ret", cerr.Code()),
+ log.KVString("msg", cerr.Message()),
+ log.KVString("stack", fmt.Sprintf("%+v", err)),
+ log.KVString("err", errmsg),
+ log.KVFloat64("timeout_quota", quota),
+ log.KVFloat64("ts", dt.Seconds()),
+ log.KVString("source", "http-access-log"),
+ )
+ }
+}
diff --git a/pkg/net/http/blademaster/metadata.go b/pkg/net/http/blademaster/metadata.go
index a868f4305..443e86f24 100644
--- a/pkg/net/http/blademaster/metadata.go
+++ b/pkg/net/http/blademaster/metadata.go
@@ -17,13 +17,14 @@ const (
_httpHeaderUser = "x1-bmspy-user"
_httpHeaderColor = "x1-bmspy-color"
_httpHeaderTimeout = "x1-bmspy-timeout"
+ _httpHeaderMirror = "x1-bmspy-mirror"
_httpHeaderRemoteIP = "x-backend-bm-real-ip"
_httpHeaderRemoteIPPort = "x-backend-bm-real-ipport"
)
-// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true.
+// mirror return true if x-bmspy-mirror in http header and its value is 1 or true.
func mirror(req *http.Request) bool {
- mirrorStr := req.Header.Get("x1-bilispy-mirror")
+ mirrorStr := req.Header.Get(_httpHeaderMirror)
if mirrorStr == "" {
return false
}
@@ -79,7 +80,7 @@ func timeout(req *http.Request) time.Duration {
}
// remoteIP implements a best effort algorithm to return the real client IP, it parses
-// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
+// x-backend-bm-real-ip or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
func remoteIP(req *http.Request) (remote string) {
if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" {
diff --git a/pkg/net/http/blademaster/perf.go b/pkg/net/http/blademaster/perf.go
new file mode 100644
index 000000000..361cee607
--- /dev/null
+++ b/pkg/net/http/blademaster/perf.go
@@ -0,0 +1,46 @@
+package blademaster
+
+import (
+ "flag"
+ "net/http"
+ "net/http/pprof"
+ "os"
+ "sync"
+
+ "github.com/bilibili/Kratos/pkg/conf/dsn"
+
+ "github.com/pkg/errors"
+)
+
+var (
+ _perfOnce sync.Once
+ _perfDSN string
+)
+
+func init() {
+ v := os.Getenv("HTTP_PERF")
+ if v == "" {
+ v = "tcp://0.0.0.0:2333"
+ }
+ flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.")
+}
+
+func startPerf() {
+ _perfOnce.Do(func() {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/debug/pprof/", pprof.Index)
+ mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
+ mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
+ mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
+
+ go func() {
+ d, err := dsn.Parse(_perfDSN)
+ if err != nil {
+ panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err))
+ }
+ if err := http.ListenAndServe(d.Host, mux); err != nil {
+ panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err))
+ }
+ }()
+ })
+}
diff --git a/pkg/net/http/blademaster/prometheus.go b/pkg/net/http/blademaster/prometheus.go
new file mode 100644
index 000000000..68af2ee58
--- /dev/null
+++ b/pkg/net/http/blademaster/prometheus.go
@@ -0,0 +1,12 @@
+package blademaster
+
+import (
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+)
+
+func monitor() HandlerFunc {
+ return func(c *Context) {
+ h := promhttp.Handler()
+ h.ServeHTTP(c.Writer, c.Request)
+ }
+}
diff --git a/pkg/net/http/blademaster/recovery.go b/pkg/net/http/blademaster/recovery.go
new file mode 100644
index 000000000..403d8a5d9
--- /dev/null
+++ b/pkg/net/http/blademaster/recovery.go
@@ -0,0 +1,32 @@
+package blademaster
+
+import (
+ "fmt"
+ "net/http/httputil"
+ "os"
+ "runtime"
+
+ "github.com/bilibili/Kratos/pkg/log"
+)
+
+// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
+func Recovery() HandlerFunc {
+ return func(c *Context) {
+ defer func() {
+ var rawReq []byte
+ if err := recover(); err != nil {
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ if c.Request != nil {
+ rawReq, _ = httputil.DumpRequest(c.Request, false)
+ }
+ pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf)
+ fmt.Fprintf(os.Stderr, pl)
+ log.Error(pl)
+ c.AbortWithStatus(500)
+ }
+ }()
+ c.Next()
+ }
+}
diff --git a/pkg/net/http/blademaster/render/data.go b/pkg/net/http/blademaster/render/data.go
new file mode 100644
index 000000000..d602350b0
--- /dev/null
+++ b/pkg/net/http/blademaster/render/data.go
@@ -0,0 +1,30 @@
+package render
+
+import (
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+// Data common bytes struct.
+type Data struct {
+ ContentType string
+ Data [][]byte
+}
+
+// Render (Data) writes data with custom ContentType.
+func (r Data) Render(w http.ResponseWriter) (err error) {
+ r.WriteContentType(w)
+ for _, d := range r.Data {
+ if _, err = w.Write(d); err != nil {
+ err = errors.WithStack(err)
+ return
+ }
+ }
+ return
+}
+
+// WriteContentType writes data with custom ContentType.
+func (r Data) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, []string{r.ContentType})
+}
diff --git a/pkg/net/http/blademaster/render/json.go b/pkg/net/http/blademaster/render/json.go
new file mode 100644
index 000000000..5a5f23bff
--- /dev/null
+++ b/pkg/net/http/blademaster/render/json.go
@@ -0,0 +1,58 @@
+package render
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+var jsonContentType = []string{"application/json; charset=utf-8"}
+
+// JSON common json struct.
+type JSON struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ TTL int `json:"ttl"`
+ Data interface{} `json:"data,omitempty"`
+}
+
+func writeJSON(w http.ResponseWriter, obj interface{}) (err error) {
+ var jsonBytes []byte
+ writeContentType(w, jsonContentType)
+ if jsonBytes, err = json.Marshal(obj); err != nil {
+ err = errors.WithStack(err)
+ return
+ }
+ if _, err = w.Write(jsonBytes); err != nil {
+ err = errors.WithStack(err)
+ }
+ return
+}
+
+// Render (JSON) writes data with json ContentType.
+func (r JSON) Render(w http.ResponseWriter) error {
+ // FIXME(zhoujiahui): the TTL field will be configurable in the future
+ if r.TTL <= 0 {
+ r.TTL = 1
+ }
+ return writeJSON(w, r)
+}
+
+// WriteContentType write json ContentType.
+func (r JSON) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, jsonContentType)
+}
+
+// MapJSON common map json struct.
+type MapJSON map[string]interface{}
+
+// Render (MapJSON) writes data with json ContentType.
+func (m MapJSON) Render(w http.ResponseWriter) error {
+ return writeJSON(w, m)
+}
+
+// WriteContentType write json ContentType.
+func (m MapJSON) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, jsonContentType)
+}
diff --git a/pkg/net/http/blademaster/render/protobuf.go b/pkg/net/http/blademaster/render/protobuf.go
new file mode 100644
index 000000000..4664f2b5f
--- /dev/null
+++ b/pkg/net/http/blademaster/render/protobuf.go
@@ -0,0 +1,38 @@
+package render
+
+import (
+ "net/http"
+
+ "github.com/gogo/protobuf/proto"
+ "github.com/pkg/errors"
+)
+
+var pbContentType = []string{"application/x-protobuf"}
+
+// Render (PB) writes data with protobuf ContentType.
+func (r PB) Render(w http.ResponseWriter) error {
+ if r.TTL <= 0 {
+ r.TTL = 1
+ }
+ return writePB(w, r)
+}
+
+// WriteContentType write protobuf ContentType.
+func (r PB) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, pbContentType)
+}
+
+func writePB(w http.ResponseWriter, obj PB) (err error) {
+ var pbBytes []byte
+ writeContentType(w, pbContentType)
+
+ if pbBytes, err = proto.Marshal(&obj); err != nil {
+ err = errors.WithStack(err)
+ return
+ }
+
+ if _, err = w.Write(pbBytes); err != nil {
+ err = errors.WithStack(err)
+ }
+ return
+}
diff --git a/pkg/net/http/blademaster/render/redirect.go b/pkg/net/http/blademaster/render/redirect.go
new file mode 100644
index 000000000..73e516d65
--- /dev/null
+++ b/pkg/net/http/blademaster/render/redirect.go
@@ -0,0 +1,26 @@
+package render
+
+import (
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+// Redirect render for redirect to specified location.
+type Redirect struct {
+ Code int
+ Request *http.Request
+ Location string
+}
+
+// Render (Redirect) redirect to specified location.
+func (r Redirect) Render(w http.ResponseWriter) error {
+ if (r.Code < 300 || r.Code > 308) && r.Code != 201 {
+ return errors.Errorf("Cannot redirect with status code %d", r.Code)
+ }
+ http.Redirect(w, r.Request, r.Location, r.Code)
+ return nil
+}
+
+// WriteContentType noneContentType.
+func (r Redirect) WriteContentType(http.ResponseWriter) {}
diff --git a/pkg/net/http/blademaster/render/render.go b/pkg/net/http/blademaster/render/render.go
new file mode 100644
index 000000000..13188637e
--- /dev/null
+++ b/pkg/net/http/blademaster/render/render.go
@@ -0,0 +1,30 @@
+package render
+
+import (
+ "net/http"
+)
+
+// Render http reponse render.
+type Render interface {
+ // Render render it to http response writer.
+ Render(http.ResponseWriter) error
+ // WriteContentType write content-type to http response writer.
+ WriteContentType(w http.ResponseWriter)
+}
+
+var (
+ _ Render = JSON{}
+ _ Render = MapJSON{}
+ _ Render = XML{}
+ _ Render = String{}
+ _ Render = Redirect{}
+ _ Render = Data{}
+ _ Render = PB{}
+)
+
+func writeContentType(w http.ResponseWriter, value []string) {
+ header := w.Header()
+ if val := header["Content-Type"]; len(val) == 0 {
+ header["Content-Type"] = value
+ }
+}
diff --git a/pkg/net/http/blademaster/render/render.pb.go b/pkg/net/http/blademaster/render/render.pb.go
new file mode 100644
index 000000000..bb5390e98
--- /dev/null
+++ b/pkg/net/http/blademaster/render/render.pb.go
@@ -0,0 +1,89 @@
+// Code generated by protoc-gen-gogo. DO NOT EDIT.
+// source: pb.proto
+
+/*
+Package render is a generated protocol buffer package.
+
+It is generated from these files:
+ pb.proto
+
+It has these top-level messages:
+ PB
+*/
+package render
+
+import proto "github.com/gogo/protobuf/proto"
+import fmt "fmt"
+import math "math"
+import google_protobuf "github.com/gogo/protobuf/types"
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
+
+type PB struct {
+ Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"`
+ Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"`
+ TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"`
+ Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"`
+}
+
+func (m *PB) Reset() { *m = PB{} }
+func (m *PB) String() string { return proto.CompactTextString(m) }
+func (*PB) ProtoMessage() {}
+func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} }
+
+func (m *PB) GetCode() int64 {
+ if m != nil {
+ return m.Code
+ }
+ return 0
+}
+
+func (m *PB) GetMessage() string {
+ if m != nil {
+ return m.Message
+ }
+ return ""
+}
+
+func (m *PB) GetTTL() uint64 {
+ if m != nil {
+ return m.TTL
+ }
+ return 0
+}
+
+func (m *PB) GetData() *google_protobuf.Any {
+ if m != nil {
+ return m.Data
+ }
+ return nil
+}
+
+func init() {
+ proto.RegisterType((*PB)(nil), "render.PB")
+}
+
+func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) }
+
+var fileDescriptorPb = []byte{
+ // 154 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b,
+ 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf,
+ 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94,
+ 0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30,
+ 0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89,
+ 0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48,
+ 0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92,
+ 0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66,
+ 0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00,
+ 0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00,
+}
diff --git a/pkg/net/http/blademaster/render/render.proto b/pkg/net/http/blademaster/render/render.proto
new file mode 100644
index 000000000..e3f53870f
--- /dev/null
+++ b/pkg/net/http/blademaster/render/render.proto
@@ -0,0 +1,14 @@
+// use under command to generate pb.pb.go
+// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto
+syntax = "proto3";
+package render;
+
+import "google/protobuf/any.proto";
+import "github.com/gogo/protobuf/gogoproto/gogo.proto";
+
+message PB {
+ int64 Code = 1;
+ string Message = 2;
+ uint64 TTL = 3;
+ google.protobuf.Any Data = 4;
+}
\ No newline at end of file
diff --git a/pkg/net/http/blademaster/render/string.go b/pkg/net/http/blademaster/render/string.go
new file mode 100644
index 000000000..4112b5713
--- /dev/null
+++ b/pkg/net/http/blademaster/render/string.go
@@ -0,0 +1,40 @@
+package render
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+var plainContentType = []string{"text/plain; charset=utf-8"}
+
+// String common string struct.
+type String struct {
+ Format string
+ Data []interface{}
+}
+
+// Render (String) writes data with custom ContentType.
+func (r String) Render(w http.ResponseWriter) error {
+ return writeString(w, r.Format, r.Data)
+}
+
+// WriteContentType writes string with text/plain ContentType.
+func (r String) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, plainContentType)
+}
+
+func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) {
+ writeContentType(w, plainContentType)
+ if len(data) > 0 {
+ _, err = fmt.Fprintf(w, format, data...)
+ } else {
+ _, err = io.WriteString(w, format)
+ }
+ if err != nil {
+ err = errors.WithStack(err)
+ }
+ return
+}
diff --git a/pkg/net/http/blademaster/render/xml.go b/pkg/net/http/blademaster/render/xml.go
new file mode 100644
index 000000000..8837c582c
--- /dev/null
+++ b/pkg/net/http/blademaster/render/xml.go
@@ -0,0 +1,31 @@
+package render
+
+import (
+ "encoding/xml"
+ "net/http"
+
+ "github.com/pkg/errors"
+)
+
+// XML common xml struct.
+type XML struct {
+ Code int
+ Message string
+ Data interface{}
+}
+
+var xmlContentType = []string{"application/xml; charset=utf-8"}
+
+// Render (XML) writes data with xml ContentType.
+func (r XML) Render(w http.ResponseWriter) (err error) {
+ r.WriteContentType(w)
+ if err = xml.NewEncoder(w).Encode(r.Data); err != nil {
+ err = errors.WithStack(err)
+ }
+ return
+}
+
+// WriteContentType write xml ContentType.
+func (r XML) WriteContentType(w http.ResponseWriter) {
+ writeContentType(w, xmlContentType)
+}
diff --git a/pkg/net/http/blademaster/routergroup.go b/pkg/net/http/blademaster/routergroup.go
new file mode 100644
index 000000000..28d09a805
--- /dev/null
+++ b/pkg/net/http/blademaster/routergroup.go
@@ -0,0 +1,166 @@
+package blademaster
+
+import (
+ "regexp"
+)
+
+// IRouter http router framework interface.
+type IRouter interface {
+ IRoutes
+ Group(string, ...HandlerFunc) *RouterGroup
+}
+
+// IRoutes http router interface.
+type IRoutes interface {
+ UseFunc(...HandlerFunc) IRoutes
+ Use(...Handler) IRoutes
+
+ Handle(string, string, ...HandlerFunc) IRoutes
+ HEAD(string, ...HandlerFunc) IRoutes
+ GET(string, ...HandlerFunc) IRoutes
+ POST(string, ...HandlerFunc) IRoutes
+ PUT(string, ...HandlerFunc) IRoutes
+ DELETE(string, ...HandlerFunc) IRoutes
+}
+
+// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
+// and an array of handlers (middleware).
+type RouterGroup struct {
+ Handlers []HandlerFunc
+ basePath string
+ engine *Engine
+ root bool
+ baseConfig *MethodConfig
+}
+
+var _ IRouter = &RouterGroup{}
+
+// Use adds middleware to the group, see example code in doc.
+func (group *RouterGroup) Use(middleware ...Handler) IRoutes {
+ for _, m := range middleware {
+ group.Handlers = append(group.Handlers, m.ServeHTTP)
+ }
+ return group.returnObj()
+}
+
+// UseFunc adds middleware to the group, see example code in doc.
+func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes {
+ group.Handlers = append(group.Handlers, middleware...)
+ return group.returnObj()
+}
+
+// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
+// For example, all the routes that use a common middlware for authorization could be grouped.
+func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
+ return &RouterGroup{
+ Handlers: group.combineHandlers(handlers),
+ basePath: group.calculateAbsolutePath(relativePath),
+ engine: group.engine,
+ root: false,
+ }
+}
+
+// SetMethodConfig is used to set config on specified method
+func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup {
+ group.baseConfig = config
+ return group
+}
+
+// BasePath router group base path.
+func (group *RouterGroup) BasePath() string {
+ return group.basePath
+}
+
+func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
+ absolutePath := group.calculateAbsolutePath(relativePath)
+ injections := group.injections(relativePath)
+ handlers = group.combineHandlers(injections, handlers)
+ group.engine.addRoute(httpMethod, absolutePath, handlers...)
+ if group.baseConfig != nil {
+ group.engine.SetMethodConfig(absolutePath, group.baseConfig)
+ }
+ return group.returnObj()
+}
+
+// Handle registers a new request handle and middleware with the given path and method.
+// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
+// See the example code in doc.
+//
+// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
+// functions can be used.
+//
+// This function is intended for bulk loading and to allow the usage of less
+// frequently used, non-standardized or custom methods (e.g. for internal
+// communication with a proxy).
+func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
+ if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil {
+ panic("http method " + httpMethod + " is not valid")
+ }
+ return group.handle(httpMethod, relativePath, handlers...)
+}
+
+// HEAD is a shortcut for router.Handle("HEAD", path, handle).
+func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes {
+ return group.handle("HEAD", relativePath, handlers...)
+}
+
+// GET is a shortcut for router.Handle("GET", path, handle).
+func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
+ return group.handle("GET", relativePath, handlers...)
+}
+
+// POST is a shortcut for router.Handle("POST", path, handle).
+func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes {
+ return group.handle("POST", relativePath, handlers...)
+}
+
+// PUT is a shortcut for router.Handle("PUT", path, handle).
+func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes {
+ return group.handle("PUT", relativePath, handlers...)
+}
+
+// DELETE is a shortcut for router.Handle("DELETE", path, handle).
+func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes {
+ return group.handle("DELETE", relativePath, handlers...)
+}
+
+func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc {
+ finalSize := len(group.Handlers)
+ for _, handlers := range handlerGroups {
+ finalSize += len(handlers)
+ }
+ if finalSize >= int(_abortIndex) {
+ panic("too many handlers")
+ }
+ mergedHandlers := make([]HandlerFunc, finalSize)
+ copy(mergedHandlers, group.Handlers)
+ position := len(group.Handlers)
+ for _, handlers := range handlerGroups {
+ copy(mergedHandlers[position:], handlers)
+ position += len(handlers)
+ }
+ return mergedHandlers
+}
+
+func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
+ return joinPaths(group.basePath, relativePath)
+}
+
+func (group *RouterGroup) returnObj() IRoutes {
+ if group.root {
+ return group.engine
+ }
+ return group
+}
+
+// injections is
+func (group *RouterGroup) injections(relativePath string) []HandlerFunc {
+ absPath := group.calculateAbsolutePath(relativePath)
+ for _, injection := range group.engine.injections {
+ if !injection.pattern.MatchString(absPath) {
+ continue
+ }
+ return injection.handlers
+ }
+ return nil
+}
diff --git a/pkg/net/http/blademaster/server.go b/pkg/net/http/blademaster/server.go
new file mode 100644
index 000000000..c96513ba3
--- /dev/null
+++ b/pkg/net/http/blademaster/server.go
@@ -0,0 +1,445 @@
+package blademaster
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "regexp"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/bilibili/Kratos/pkg/conf/dsn"
+ "github.com/bilibili/Kratos/pkg/log"
+ "github.com/bilibili/Kratos/pkg/net/ip"
+ "github.com/bilibili/Kratos/pkg/net/metadata"
+ "github.com/bilibili/Kratos/pkg/stat"
+ xtime "github.com/bilibili/Kratos/pkg/time"
+
+ "github.com/pkg/errors"
+)
+
+const (
+ defaultMaxMemory = 32 << 20 // 32 MB
+)
+
+var (
+ _ IRouter = &Engine{}
+ stats = stat.HTTPServer
+
+ _httpDSN string
+)
+
+func init() {
+ addFlag(flag.CommandLine)
+}
+
+func addFlag(fs *flag.FlagSet) {
+ v := os.Getenv("HTTP")
+ if v == "" {
+ v = "tcp://0.0.0.0:8000/?timeout=1s"
+ }
+ fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.")
+}
+
+func parseDSN(rawdsn string) *ServerConfig {
+ conf := new(ServerConfig)
+ d, err := dsn.Parse(rawdsn)
+ if err != nil {
+ panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
+ }
+ if _, err = d.Bind(conf); err != nil {
+ panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
+ }
+ return conf
+}
+
+// Handler responds to an HTTP request.
+type Handler interface {
+ ServeHTTP(c *Context)
+}
+
+// HandlerFunc http request handler function.
+type HandlerFunc func(*Context)
+
+// ServeHTTP calls f(ctx).
+func (f HandlerFunc) ServeHTTP(c *Context) {
+ f(c)
+}
+
+// ServerConfig is the bm server config model
+type ServerConfig struct {
+ Network string `dsn:"network"`
+ // FIXME: rename to Address
+ Addr string `dsn:"address"`
+ Timeout xtime.Duration `dsn:"query.timeout"`
+ ReadTimeout xtime.Duration `dsn:"query.readTimeout"`
+ WriteTimeout xtime.Duration `dsn:"query.writeTimeout"`
+}
+
+// MethodConfig is
+type MethodConfig struct {
+ Timeout xtime.Duration
+}
+
+// Start listen and serve bm engine by given DSN.
+func (engine *Engine) Start() error {
+ conf := engine.conf
+ l, err := net.Listen(conf.Network, conf.Addr)
+ if err != nil {
+ errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr)
+ return err
+ }
+
+ log.Info("blademaster: start http listen addr: %s", conf.Addr)
+ server := &http.Server{
+ ReadTimeout: time.Duration(conf.ReadTimeout),
+ WriteTimeout: time.Duration(conf.WriteTimeout),
+ }
+ go func() {
+ if err := engine.RunServer(server, l); err != nil {
+ if errors.Cause(err) == http.ErrServerClosed {
+ log.Info("blademaster: server closed")
+ return
+ }
+ panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l))
+ }
+ }()
+
+ return nil
+}
+
+// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
+// Create an instance of Engine, by using New() or Default()
+type Engine struct {
+ RouterGroup
+
+ lock sync.RWMutex
+ conf *ServerConfig
+
+ address string
+
+ mux *http.ServeMux // http mux router
+ server atomic.Value // store *http.Server
+ metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
+
+ pcLock sync.RWMutex
+ methodConfigs map[string]*MethodConfig
+
+ injections []injection
+}
+
+type injection struct {
+ pattern *regexp.Regexp
+ handlers []HandlerFunc
+}
+
+// New returns a new blank Engine instance without any middleware attached.
+//
+// Deprecated: please use NewServer.
+func New() *Engine {
+ engine := &Engine{
+ RouterGroup: RouterGroup{
+ Handlers: nil,
+ basePath: "/",
+ root: true,
+ },
+ address: ip.InternalIP(),
+ conf: &ServerConfig{
+ Timeout: xtime.Duration(time.Second),
+ },
+ mux: http.NewServeMux(),
+ metastore: make(map[string]map[string]interface{}),
+ methodConfigs: make(map[string]*MethodConfig),
+ injections: make([]injection, 0),
+ }
+ engine.RouterGroup.engine = engine
+ // NOTE add prometheus monitor location
+ engine.addRoute("GET", "/metrics", monitor())
+ engine.addRoute("GET", "/metadata", engine.metadata())
+ startPerf()
+ return engine
+}
+
+// NewServer returns a new blank Engine instance without any middleware attached.
+func NewServer(conf *ServerConfig) *Engine {
+ if conf == nil {
+ if !flag.Parsed() {
+ fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n")
+ }
+ conf = parseDSN(_httpDSN)
+ } else {
+ fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n")
+ }
+
+ engine := &Engine{
+ RouterGroup: RouterGroup{
+ Handlers: nil,
+ basePath: "/",
+ root: true,
+ },
+ address: ip.InternalIP(),
+ mux: http.NewServeMux(),
+ metastore: make(map[string]map[string]interface{}),
+ methodConfigs: make(map[string]*MethodConfig),
+ }
+ if err := engine.SetConfig(conf); err != nil {
+ panic(err)
+ }
+ engine.RouterGroup.engine = engine
+ // NOTE add prometheus monitor location
+ engine.addRoute("GET", "/metrics", monitor())
+ engine.addRoute("GET", "/metadata", engine.metadata())
+ startPerf()
+ return engine
+}
+
+// SetMethodConfig is used to set config on specified path
+func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) {
+ engine.pcLock.Lock()
+ engine.methodConfigs[path] = mc
+ engine.pcLock.Unlock()
+}
+
+// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
+func DefaultServer(conf *ServerConfig) *Engine {
+ engine := NewServer(conf)
+ engine.Use(Recovery(), Trace(), Logger())
+ return engine
+}
+
+// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
+//
+// Deprecated: please use DefaultServer.
+func Default() *Engine {
+ engine := New()
+ engine.Use(Recovery(), Trace(), Logger())
+ return engine
+}
+
+func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) {
+ if path[0] != '/' {
+ panic("blademaster: path must begin with '/'")
+ }
+ if method == "" {
+ panic("blademaster: HTTP method can not be empty")
+ }
+ if len(handlers) == 0 {
+ panic("blademaster: there must be at least one handler")
+ }
+ if _, ok := engine.metastore[path]; !ok {
+ engine.metastore[path] = make(map[string]interface{})
+ }
+ engine.metastore[path]["method"] = method
+ engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
+ c := &Context{
+ Context: nil,
+ engine: engine,
+ index: -1,
+ handlers: nil,
+ Keys: nil,
+ method: "",
+ Error: nil,
+ }
+
+ c.Request = req
+ c.Writer = w
+ c.handlers = handlers
+ c.method = method
+
+ engine.handleContext(c)
+ })
+}
+
+// SetConfig is used to set the engine configuration.
+// Only the valid config will be loaded.
+func (engine *Engine) SetConfig(conf *ServerConfig) (err error) {
+ if conf.Timeout <= 0 {
+ return errors.New("blademaster: config timeout must greater than 0")
+ }
+ if conf.Network == "" {
+ conf.Network = "tcp"
+ }
+ engine.lock.Lock()
+ engine.conf = conf
+ engine.lock.Unlock()
+ return
+}
+
+func (engine *Engine) methodConfig(path string) *MethodConfig {
+ engine.pcLock.RLock()
+ mc := engine.methodConfigs[path]
+ engine.pcLock.RUnlock()
+ return mc
+}
+
+func (engine *Engine) handleContext(c *Context) {
+ var cancel func()
+ req := c.Request
+ ctype := req.Header.Get("Content-Type")
+ switch {
+ case strings.Contains(ctype, "multipart/form-data"):
+ req.ParseMultipartForm(defaultMaxMemory)
+ default:
+ req.ParseForm()
+ }
+ // get derived timeout from http request header,
+ // compare with the engine configured,
+ // and use the minimum one
+ engine.lock.RLock()
+ tm := time.Duration(engine.conf.Timeout)
+ engine.lock.RUnlock()
+ // the method config is preferred
+ if pc := engine.methodConfig(c.Request.URL.Path); pc != nil {
+ tm = time.Duration(pc.Timeout)
+ }
+ if ctm := timeout(req); ctm > 0 && tm > ctm {
+ tm = ctm
+ }
+ md := metadata.MD{
+ metadata.Color: color(req),
+ metadata.RemoteIP: remoteIP(req),
+ metadata.RemotePort: remotePort(req),
+ metadata.Caller: caller(req),
+ metadata.Mirror: mirror(req),
+ }
+ ctx := metadata.NewContext(context.Background(), md)
+ if tm > 0 {
+ c.Context, cancel = context.WithTimeout(ctx, tm)
+ } else {
+ c.Context, cancel = context.WithCancel(ctx)
+ }
+ defer cancel()
+ c.Next()
+}
+
+// Router return a http.Handler for using http.ListenAndServe() directly.
+func (engine *Engine) Router() http.Handler {
+ return engine.mux
+}
+
+// Server is used to load stored http server.
+func (engine *Engine) Server() *http.Server {
+ s, ok := engine.server.Load().(*http.Server)
+ if !ok {
+ return nil
+ }
+ return s
+}
+
+// Shutdown the http server without interrupting active connections.
+func (engine *Engine) Shutdown(ctx context.Context) error {
+ server := engine.Server()
+ if server == nil {
+ return errors.New("blademaster: no server")
+ }
+ return errors.WithStack(server.Shutdown(ctx))
+}
+
+// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
+// included in the handlers chain for every single request. Even 404, 405, static files...
+// For example, this is the right place for a logger or error management middleware.
+func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes {
+ engine.RouterGroup.UseFunc(middleware...)
+ return engine
+}
+
+// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
+// included in the handlers chain for every single request. Even 404, 405, static files...
+// For example, this is the right place for a logger or error management middleware.
+func (engine *Engine) Use(middleware ...Handler) IRoutes {
+ engine.RouterGroup.Use(middleware...)
+ return engine
+}
+
+// Ping is used to set the general HTTP ping handler.
+func (engine *Engine) Ping(handler HandlerFunc) {
+ engine.GET("/monitor/ping", handler)
+}
+
+// Register is used to export metadata to discovery.
+func (engine *Engine) Register(handler HandlerFunc) {
+ engine.GET("/register", handler)
+}
+
+// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
+// It is a shortcut for http.ListenAndServe(addr, router)
+// Note: this method will block the calling goroutine indefinitely unless an error happens.
+func (engine *Engine) Run(addr ...string) (err error) {
+ address := resolveAddress(addr)
+ server := &http.Server{
+ Addr: address,
+ Handler: engine.mux,
+ }
+ engine.server.Store(server)
+ if err = server.ListenAndServe(); err != nil {
+ err = errors.Wrapf(err, "addrs: %v", addr)
+ }
+ return
+}
+
+// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
+// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
+// Note: this method will block the calling goroutine indefinitely unless an error happens.
+func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
+ server := &http.Server{
+ Addr: addr,
+ Handler: engine.mux,
+ }
+ engine.server.Store(server)
+ if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
+ err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile)
+ }
+ return
+}
+
+// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
+// through the specified unix socket (ie. a file).
+// Note: this method will block the calling goroutine indefinitely unless an error happens.
+func (engine *Engine) RunUnix(file string) (err error) {
+ os.Remove(file)
+ listener, err := net.Listen("unix", file)
+ if err != nil {
+ err = errors.Wrapf(err, "unix: %s", file)
+ return
+ }
+ defer listener.Close()
+ server := &http.Server{
+ Handler: engine.mux,
+ }
+ engine.server.Store(server)
+ if err = server.Serve(listener); err != nil {
+ err = errors.Wrapf(err, "unix: %s", file)
+ }
+ return
+}
+
+// RunServer will serve and start listening HTTP requests by given server and listener.
+// Note: this method will block the calling goroutine indefinitely unless an error happens.
+func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) {
+ server.Handler = engine.mux
+ engine.server.Store(server)
+ if err = server.Serve(l); err != nil {
+ err = errors.Wrapf(err, "listen server: %+v/%+v", server, l)
+ return
+ }
+ return
+}
+
+func (engine *Engine) metadata() HandlerFunc {
+ return func(c *Context) {
+ c.JSON(engine.metastore, nil)
+ }
+}
+
+// Inject is
+func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
+ engine.injections = append(engine.injections, injection{
+ pattern: regexp.MustCompile(pattern),
+ handlers: handlers,
+ })
+}
diff --git a/pkg/net/http/blademaster/trace.go b/pkg/net/http/blademaster/trace.go
index b792246bb..f607f18ca 100644
--- a/pkg/net/http/blademaster/trace.go
+++ b/pkg/net/http/blademaster/trace.go
@@ -4,12 +4,43 @@ import (
"io"
"net/http"
"net/http/httptrace"
+ "strconv"
+ "github.com/bilibili/Kratos/pkg/net/metadata"
"github.com/bilibili/Kratos/pkg/net/trace"
)
const _defaultComponentName = "net/http"
+// Trace is trace middleware
+func Trace() HandlerFunc {
+ return func(c *Context) {
+ // handle http request
+ // get derived trace from http request header
+ t, err := trace.Extract(trace.HTTPFormat, c.Request.Header)
+ if err != nil {
+ var opts []trace.Option
+ if ok, _ := strconv.ParseBool(trace.KratosTraceDebug); ok {
+ opts = append(opts, trace.EnableDebug())
+ }
+ t = trace.New(c.Request.URL.Path, opts...)
+ }
+ t.SetTitle(c.Request.URL.Path)
+ t.SetTag(trace.String(trace.TagComponent, _defaultComponentName))
+ t.SetTag(trace.String(trace.TagHTTPMethod, c.Request.Method))
+ t.SetTag(trace.String(trace.TagHTTPURL, c.Request.URL.String()))
+ t.SetTag(trace.String(trace.TagSpanKind, "server"))
+ // business tag
+ t.SetTag(trace.String("caller", metadata.String(c.Context, metadata.Caller)))
+ // export trace id to user.
+ // TODO(zhoujiahui): trace package should be updated
+ // c.Writer.Header().Set(trace.KratosTraceID, t.TraceID())
+ c.Context = trace.NewContext(c.Context, t)
+ c.Next()
+ t.Finish(&c.Error)
+ }
+}
+
type closeTracker struct {
io.ReadCloser
tr trace.Trace
diff --git a/pkg/net/http/blademaster/utils.go b/pkg/net/http/blademaster/utils.go
new file mode 100644
index 000000000..54da96d3b
--- /dev/null
+++ b/pkg/net/http/blademaster/utils.go
@@ -0,0 +1,42 @@
+package blademaster
+
+import (
+ "os"
+ "path"
+)
+
+func lastChar(str string) uint8 {
+ if str == "" {
+ panic("The length of the string can't be 0")
+ }
+ return str[len(str)-1]
+}
+
+func joinPaths(absolutePath, relativePath string) string {
+ if relativePath == "" {
+ return absolutePath
+ }
+
+ finalPath := path.Join(absolutePath, relativePath)
+ appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/'
+ if appendSlash {
+ return finalPath + "/"
+ }
+ return finalPath
+}
+
+func resolveAddress(addr []string) string {
+ switch len(addr) {
+ case 0:
+ if port := os.Getenv("PORT"); port != "" {
+ //debugPrint("Environment variable PORT=\"%s\"", port)
+ return ":" + port
+ }
+ //debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
+ return ":8080"
+ case 1:
+ return addr[0]
+ default:
+ panic("too much parameters")
+ }
+}