blademaster initial (#6)
parent
1efe0a084e
commit
96d32e866a
@ -0,0 +1,85 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
"strings" |
||||
|
||||
"gopkg.in/go-playground/validator.v9" |
||||
) |
||||
|
||||
// MIME
|
||||
const ( |
||||
MIMEJSON = "application/json" |
||||
MIMEHTML = "text/html" |
||||
MIMEXML = "application/xml" |
||||
MIMEXML2 = "text/xml" |
||||
MIMEPlain = "text/plain" |
||||
MIMEPOSTForm = "application/x-www-form-urlencoded" |
||||
MIMEMultipartPOSTForm = "multipart/form-data" |
||||
) |
||||
|
||||
// Binding http binding request interface.
|
||||
type Binding interface { |
||||
Name() string |
||||
Bind(*http.Request, interface{}) error |
||||
} |
||||
|
||||
// StructValidator http validator interface.
|
||||
type StructValidator interface { |
||||
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||
// If the received type is not a struct, any validation should be skipped and nil must be returned.
|
||||
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||
// Otherwise nil must be returned.
|
||||
ValidateStruct(interface{}) error |
||||
|
||||
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||
RegisterValidation(string, validator.Func) error |
||||
} |
||||
|
||||
// Validator default validator.
|
||||
var Validator StructValidator = &defaultValidator{} |
||||
|
||||
// Binding
|
||||
var ( |
||||
JSON = jsonBinding{} |
||||
XML = xmlBinding{} |
||||
Form = formBinding{} |
||||
Query = queryBinding{} |
||||
FormPost = formPostBinding{} |
||||
FormMultipart = formMultipartBinding{} |
||||
) |
||||
|
||||
// Default get by binding type by method and contexttype.
|
||||
func Default(method, contentType string) Binding { |
||||
if method == "GET" { |
||||
return Form |
||||
} |
||||
|
||||
contentType = stripContentTypeParam(contentType) |
||||
switch contentType { |
||||
case MIMEJSON: |
||||
return JSON |
||||
case MIMEXML, MIMEXML2: |
||||
return XML |
||||
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
|
||||
return Form |
||||
} |
||||
} |
||||
|
||||
func validate(obj interface{}) error { |
||||
if Validator == nil { |
||||
return nil |
||||
} |
||||
return Validator.ValidateStruct(obj) |
||||
} |
||||
|
||||
func stripContentTypeParam(contentType string) string { |
||||
i := strings.Index(contentType, ";") |
||||
if i != -1 { |
||||
contentType = contentType[:i] |
||||
} |
||||
return contentType |
||||
} |
@ -0,0 +1,342 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"bytes" |
||||
"mime/multipart" |
||||
"net/http" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type FooStruct struct { |
||||
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"` |
||||
} |
||||
|
||||
type FooBarStruct struct { |
||||
FooStruct |
||||
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"` |
||||
Slice []string `form:"slice" validate:"max=10"` |
||||
} |
||||
|
||||
type ComplexDefaultStruct struct { |
||||
Int int `form:"int" default:"999"` |
||||
String string `form:"string" default:"default-string"` |
||||
Bool bool `form:"bool" default:"false"` |
||||
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"` |
||||
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"` |
||||
} |
||||
|
||||
type Int8SliceStruct struct { |
||||
State []int8 `form:"state,split"` |
||||
} |
||||
|
||||
type Int64SliceStruct struct { |
||||
State []int64 `form:"state,split"` |
||||
} |
||||
|
||||
type StringSliceStruct struct { |
||||
State []string `form:"state,split"` |
||||
} |
||||
|
||||
func TestBindingDefault(t *testing.T) { |
||||
assert.Equal(t, Default("GET", ""), Form) |
||||
assert.Equal(t, Default("GET", MIMEJSON), Form) |
||||
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON), JSON) |
||||
assert.Equal(t, Default("PUT", MIMEJSON), JSON) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON) |
||||
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEXML), XML) |
||||
assert.Equal(t, Default("PUT", MIMEXML2), XML) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm), Form) |
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||
|
||||
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) |
||||
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) |
||||
|
||||
} |
||||
|
||||
func TestStripContentType(t *testing.T) { |
||||
c1 := "application/vnd.mozilla.xul+xml" |
||||
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8" |
||||
assert.Equal(t, stripContentTypeParam(c1), c1) |
||||
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml") |
||||
} |
||||
|
||||
func TestBindInt8Form(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(Int8SliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []int8{1, 2, 3}, q.State) |
||||
|
||||
params = "state=1,2,3,256" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.Error(t, Form.Bind(req, q)) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
|
||||
params = "state=1,,2" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int8SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.EqualValues(t, []int8{1, 2}, q.State) |
||||
} |
||||
|
||||
func TestBindInt64Form(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(Int64SliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []int64{1, 2, 3}, q.State) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(Int64SliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
} |
||||
|
||||
func TestBindStringForm(t *testing.T) { |
||||
params := "state=1,2,3" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(StringSliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []string{"1", "2", "3"}, q.State) |
||||
|
||||
params = "state=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(StringSliceStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Len(t, q.State, 0) |
||||
|
||||
params = "state=p,,p" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(StringSliceStruct) |
||||
Form.Bind(req, q) |
||||
assert.EqualValues(t, []string{"p", "p"}, q.State) |
||||
} |
||||
|
||||
func TestBindingJSON(t *testing.T) { |
||||
testBodyBinding(t, |
||||
JSON, "json", |
||||
"/", "/", |
||||
`{"foo": "bar"}`, `{"bar": "foo"}`) |
||||
} |
||||
|
||||
func TestBindingForm(t *testing.T) { |
||||
testFormBinding(t, "POST", |
||||
"/", "/", |
||||
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo") |
||||
} |
||||
|
||||
func TestBindingForm2(t *testing.T) { |
||||
testFormBinding(t, "GET", |
||||
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||
"", "") |
||||
} |
||||
|
||||
func TestBindingQuery(t *testing.T) { |
||||
testQueryBinding(t, "POST", |
||||
"/?foo=bar&bar=foo", "/", |
||||
"foo=unused", "bar2=foo") |
||||
} |
||||
|
||||
func TestBindingQuery2(t *testing.T) { |
||||
testQueryBinding(t, "GET", |
||||
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||
"foo=unused", "") |
||||
} |
||||
|
||||
func TestBindingXML(t *testing.T) { |
||||
testBodyBinding(t, |
||||
XML, "xml", |
||||
"/", "/", |
||||
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>") |
||||
} |
||||
|
||||
func createFormPostRequest() *http.Request { |
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) |
||||
req.Header.Set("Content-Type", MIMEPOSTForm) |
||||
return req |
||||
} |
||||
|
||||
func createFormMultipartRequest() *http.Request { |
||||
boundary := "--testboundary" |
||||
body := new(bytes.Buffer) |
||||
mw := multipart.NewWriter(body) |
||||
defer mw.Close() |
||||
|
||||
mw.SetBoundary(boundary) |
||||
mw.WriteField("foo", "bar") |
||||
mw.WriteField("bar", "foo") |
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) |
||||
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) |
||||
return req |
||||
} |
||||
|
||||
func TestBindingFormPost(t *testing.T) { |
||||
req := createFormPostRequest() |
||||
var obj FooBarStruct |
||||
FormPost.Bind(req, &obj) |
||||
|
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func TestBindingFormMultipart(t *testing.T) { |
||||
req := createFormMultipartRequest() |
||||
var obj FooBarStruct |
||||
FormMultipart.Bind(req, &obj) |
||||
|
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func TestValidationFails(t *testing.T) { |
||||
var obj FooStruct |
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func TestValidationDisabled(t *testing.T) { |
||||
backup := Validator |
||||
Validator = nil |
||||
defer func() { Validator = backup }() |
||||
|
||||
var obj FooStruct |
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestExistsSucceeds(t *testing.T) { |
||||
type HogeStruct struct { |
||||
Hoge *int `json:"hoge" binding:"exists"` |
||||
} |
||||
|
||||
var obj HogeStruct |
||||
req := requestWithBody("POST", "/", `{"hoge": 0}`) |
||||
err := JSON.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestFormDefaultValue(t *testing.T) { |
||||
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8" |
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q := new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 333, q.Int) |
||||
assert.Equal(t, "hello", q.String) |
||||
assert.Equal(t, true, q.Bool) |
||||
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice) |
||||
|
||||
params = "string=hello&bool=false" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "hello", q.String) |
||||
assert.Equal(t, false, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
|
||||
params = "strings=hello" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "default-string", q.String) |
||||
assert.Equal(t, false, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
|
||||
params = "int=&string=&bool=true&int64_slice=&int8_slice=" |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
q = new(ComplexDefaultStruct) |
||||
assert.NoError(t, Form.Bind(req, q)) |
||||
assert.Equal(t, 999, q.Int) |
||||
assert.Equal(t, "default-string", q.String) |
||||
assert.Equal(t, true, q.Bool) |
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||
} |
||||
|
||||
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||
b := Form |
||||
assert.Equal(t, b.Name(), "form") |
||||
|
||||
obj := FooBarStruct{} |
||||
req := requestWithBody(method, path, body) |
||||
if method == "POST" { |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
} |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
|
||||
obj = FooBarStruct{} |
||||
req = requestWithBody(method, badPath, badBody) |
||||
err = JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||
b := Query |
||||
assert.Equal(t, b.Name(), "query") |
||||
|
||||
obj := FooBarStruct{} |
||||
req := requestWithBody(method, path, body) |
||||
if method == "POST" { |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
} |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
assert.Equal(t, obj.Bar, "foo") |
||||
} |
||||
|
||||
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { |
||||
assert.Equal(t, b.Name(), name) |
||||
|
||||
obj := FooStruct{} |
||||
req := requestWithBody("POST", path, body) |
||||
err := b.Bind(req, &obj) |
||||
assert.NoError(t, err) |
||||
assert.Equal(t, obj.Foo, "bar") |
||||
|
||||
obj = FooStruct{} |
||||
req = requestWithBody("POST", badPath, badBody) |
||||
err = JSON.Bind(req, &obj) |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func requestWithBody(method, path, body string) (req *http.Request) { |
||||
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body)) |
||||
return |
||||
} |
||||
func BenchmarkBindingForm(b *testing.B) { |
||||
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w") |
||||
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||
f := Form |
||||
for i := 0; i < b.N; i++ { |
||||
obj := FooBarStruct{} |
||||
f.Bind(req, &obj) |
||||
} |
||||
} |
@ -0,0 +1,45 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"reflect" |
||||
"sync" |
||||
|
||||
"gopkg.in/go-playground/validator.v9" |
||||
) |
||||
|
||||
type defaultValidator struct { |
||||
once sync.Once |
||||
validate *validator.Validate |
||||
} |
||||
|
||||
var _ StructValidator = &defaultValidator{} |
||||
|
||||
func (v *defaultValidator) ValidateStruct(obj interface{}) error { |
||||
if kindOfData(obj) == reflect.Struct { |
||||
v.lazyinit() |
||||
if err := v.validate.Struct(obj); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error { |
||||
v.lazyinit() |
||||
return v.validate.RegisterValidation(key, fn) |
||||
} |
||||
|
||||
func (v *defaultValidator) lazyinit() { |
||||
v.once.Do(func() { |
||||
v.validate = validator.New() |
||||
}) |
||||
} |
||||
|
||||
func kindOfData(data interface{}) reflect.Kind { |
||||
value := reflect.ValueOf(data) |
||||
valueType := value.Kind() |
||||
if valueType == reflect.Ptr { |
||||
valueType = value.Elem().Kind() |
||||
} |
||||
return valueType |
||||
} |
@ -0,0 +1,113 @@ |
||||
// Code generated by protoc-gen-go.
|
||||
// source: test.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/* |
||||
Package example is a generated protocol buffer package. |
||||
|
||||
It is generated from these files: |
||||
test.proto |
||||
|
||||
It has these top-level messages: |
||||
Test |
||||
*/ |
||||
package example |
||||
|
||||
import proto "github.com/golang/protobuf/proto" |
||||
import math "math" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = math.Inf |
||||
|
||||
type FOO int32 |
||||
|
||||
const ( |
||||
FOO_X FOO = 17 |
||||
) |
||||
|
||||
var FOO_name = map[int32]string{ |
||||
17: "X", |
||||
} |
||||
var FOO_value = map[string]int32{ |
||||
"X": 17, |
||||
} |
||||
|
||||
func (x FOO) Enum() *FOO { |
||||
p := new(FOO) |
||||
*p = x |
||||
return p |
||||
} |
||||
func (x FOO) String() string { |
||||
return proto.EnumName(FOO_name, int32(x)) |
||||
} |
||||
func (x *FOO) UnmarshalJSON(data []byte) error { |
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
*x = FOO(value) |
||||
return nil |
||||
} |
||||
|
||||
type Test struct { |
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` |
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` |
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` |
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
} |
||||
|
||||
func (m *Test) Reset() { *m = Test{} } |
||||
func (m *Test) String() string { return proto.CompactTextString(m) } |
||||
func (*Test) ProtoMessage() {} |
||||
|
||||
const Default_Test_Type int32 = 77 |
||||
|
||||
func (m *Test) GetLabel() string { |
||||
if m != nil && m.Label != nil { |
||||
return *m.Label |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *Test) GetType() int32 { |
||||
if m != nil && m.Type != nil { |
||||
return *m.Type |
||||
} |
||||
return Default_Test_Type |
||||
} |
||||
|
||||
func (m *Test) GetReps() []int64 { |
||||
if m != nil { |
||||
return m.Reps |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup { |
||||
if m != nil { |
||||
return m.Optionalgroup |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
type Test_OptionalGroup struct { |
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` |
||||
XXX_unrecognized []byte `json:"-"` |
||||
} |
||||
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } |
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } |
||||
func (*Test_OptionalGroup) ProtoMessage() {} |
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string { |
||||
if m != nil && m.RequiredField != nil { |
||||
return *m.RequiredField |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value) |
||||
} |
@ -0,0 +1,12 @@ |
||||
package example; |
||||
|
||||
enum FOO {X=17;}; |
||||
|
||||
message Test { |
||||
required string label = 1; |
||||
optional int32 type = 2[default=77]; |
||||
repeated int64 reps = 3; |
||||
optional group OptionalGroup = 4{ |
||||
required string RequiredField = 5; |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"fmt" |
||||
"log" |
||||
"net/http" |
||||
) |
||||
|
||||
type Arg struct { |
||||
Max int64 `form:"max" validate:"max=10"` |
||||
Min int64 `form:"min" validate:"min=2"` |
||||
Range int64 `form:"range" validate:"min=1,max=10"` |
||||
// use split option to split arg 1,2,3 into slice [1 2 3]
|
||||
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
|
||||
Slice []int64 `form:"slice,split" validate:"min=1"` |
||||
} |
||||
|
||||
func ExampleBinding() { |
||||
req := initHTTP("max=9&min=3&range=3&slice=1,2,3") |
||||
arg := new(Arg) |
||||
if err := Form.Bind(req, arg); err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice) |
||||
// Output:
|
||||
// arg.Max 9
|
||||
// arg.Min 3
|
||||
// arg.Range 3
|
||||
// arg.Slice [1 2 3]
|
||||
} |
||||
|
||||
func initHTTP(params string) (req *http.Request) { |
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||
req.ParseForm() |
||||
return |
||||
} |
@ -0,0 +1,55 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const defaultMemory = 32 * 1024 * 1024 |
||||
|
||||
type formBinding struct{} |
||||
type formPostBinding struct{} |
||||
type formMultipartBinding struct{} |
||||
|
||||
func (f formBinding) Name() string { |
||||
return "form" |
||||
} |
||||
|
||||
func (f formBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseForm(); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.Form); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
||||
|
||||
func (f formPostBinding) Name() string { |
||||
return "form-urlencoded" |
||||
} |
||||
|
||||
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseForm(); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.PostForm); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
||||
|
||||
func (f formMultipartBinding) Name() string { |
||||
return "multipart/form-data" |
||||
} |
||||
|
||||
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error { |
||||
if err := req.ParseMultipartForm(defaultMemory); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
if err := mapForm(obj, req.MultipartForm.Value); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,276 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// scache struct reflect type cache.
|
||||
var scache = &cache{ |
||||
data: make(map[reflect.Type]*sinfo), |
||||
} |
||||
|
||||
type cache struct { |
||||
data map[reflect.Type]*sinfo |
||||
mutex sync.RWMutex |
||||
} |
||||
|
||||
func (c *cache) get(obj reflect.Type) (s *sinfo) { |
||||
var ok bool |
||||
c.mutex.RLock() |
||||
if s, ok = c.data[obj]; !ok { |
||||
c.mutex.RUnlock() |
||||
s = c.set(obj) |
||||
return |
||||
} |
||||
c.mutex.RUnlock() |
||||
return |
||||
} |
||||
|
||||
func (c *cache) set(obj reflect.Type) (s *sinfo) { |
||||
s = new(sinfo) |
||||
tp := obj.Elem() |
||||
for i := 0; i < tp.NumField(); i++ { |
||||
fd := new(field) |
||||
fd.tp = tp.Field(i) |
||||
tag := fd.tp.Tag.Get("form") |
||||
fd.name, fd.option = parseTag(tag) |
||||
if defV := fd.tp.Tag.Get("default"); defV != "" { |
||||
dv := reflect.New(fd.tp.Type).Elem() |
||||
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option) |
||||
fd.hasDefault = true |
||||
fd.defaultValue = dv |
||||
} |
||||
s.field = append(s.field, fd) |
||||
} |
||||
c.mutex.Lock() |
||||
c.data[obj] = s |
||||
c.mutex.Unlock() |
||||
return |
||||
} |
||||
|
||||
type sinfo struct { |
||||
field []*field |
||||
} |
||||
|
||||
type field struct { |
||||
tp reflect.StructField |
||||
name string |
||||
option tagOptions |
||||
|
||||
hasDefault bool // if field had default value
|
||||
defaultValue reflect.Value // field default value
|
||||
} |
||||
|
||||
func mapForm(ptr interface{}, form map[string][]string) error { |
||||
sinfo := scache.get(reflect.TypeOf(ptr)) |
||||
val := reflect.ValueOf(ptr).Elem() |
||||
for i, fd := range sinfo.field { |
||||
typeField := fd.tp |
||||
structField := val.Field(i) |
||||
if !structField.CanSet() { |
||||
continue |
||||
} |
||||
|
||||
structFieldKind := structField.Kind() |
||||
inputFieldName := fd.name |
||||
if inputFieldName == "" { |
||||
inputFieldName = typeField.Name |
||||
|
||||
// if "form" tag is nil, we inspect if the field is a struct.
|
||||
// this would not make sense for JSON parsing but it does for a form
|
||||
// since data is flatten
|
||||
if structFieldKind == reflect.Struct { |
||||
err := mapForm(structField.Addr().Interface(), form) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
continue |
||||
} |
||||
} |
||||
inputValue, exists := form[inputFieldName] |
||||
if !exists { |
||||
// Set the field as default value when the input value is not exist
|
||||
if fd.hasDefault { |
||||
structField.Set(fd.defaultValue) |
||||
} |
||||
continue |
||||
} |
||||
// Set the field as default value when the input value is empty
|
||||
if fd.hasDefault && inputValue[0] == "" { |
||||
structField.Set(fd.defaultValue) |
||||
continue |
||||
} |
||||
if _, isTime := structField.Interface().(time.Time); isTime { |
||||
if err := setTimeField(inputValue[0], typeField, structField); err != nil { |
||||
return err |
||||
} |
||||
continue |
||||
} |
||||
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error { |
||||
switch valueKind { |
||||
case reflect.Int: |
||||
return setIntField(val[0], 0, structField) |
||||
case reflect.Int8: |
||||
return setIntField(val[0], 8, structField) |
||||
case reflect.Int16: |
||||
return setIntField(val[0], 16, structField) |
||||
case reflect.Int32: |
||||
return setIntField(val[0], 32, structField) |
||||
case reflect.Int64: |
||||
return setIntField(val[0], 64, structField) |
||||
case reflect.Uint: |
||||
return setUintField(val[0], 0, structField) |
||||
case reflect.Uint8: |
||||
return setUintField(val[0], 8, structField) |
||||
case reflect.Uint16: |
||||
return setUintField(val[0], 16, structField) |
||||
case reflect.Uint32: |
||||
return setUintField(val[0], 32, structField) |
||||
case reflect.Uint64: |
||||
return setUintField(val[0], 64, structField) |
||||
case reflect.Bool: |
||||
return setBoolField(val[0], structField) |
||||
case reflect.Float32: |
||||
return setFloatField(val[0], 32, structField) |
||||
case reflect.Float64: |
||||
return setFloatField(val[0], 64, structField) |
||||
case reflect.String: |
||||
structField.SetString(val[0]) |
||||
case reflect.Slice: |
||||
if option.Contains("split") { |
||||
val = strings.Split(val[0], ",") |
||||
} |
||||
filtered := filterEmpty(val) |
||||
switch structField.Type().Elem().Kind() { |
||||
case reflect.Int64: |
||||
valSli := make([]int64, 0, len(filtered)) |
||||
for i := 0; i < len(filtered); i++ { |
||||
d, err := strconv.ParseInt(filtered[i], 10, 64) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
valSli = append(valSli, d) |
||||
} |
||||
structField.Set(reflect.ValueOf(valSli)) |
||||
case reflect.String: |
||||
valSli := make([]string, 0, len(filtered)) |
||||
for i := 0; i < len(filtered); i++ { |
||||
valSli = append(valSli, filtered[i]) |
||||
} |
||||
structField.Set(reflect.ValueOf(valSli)) |
||||
default: |
||||
sliceOf := structField.Type().Elem().Kind() |
||||
numElems := len(filtered) |
||||
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered)) |
||||
for i := 0; i < numElems; i++ { |
||||
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
structField.Set(slice) |
||||
} |
||||
default: |
||||
return errors.New("Unknown type") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setIntField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0" |
||||
} |
||||
intVal, err := strconv.ParseInt(val, 10, bitSize) |
||||
if err == nil { |
||||
field.SetInt(intVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setUintField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0" |
||||
} |
||||
uintVal, err := strconv.ParseUint(val, 10, bitSize) |
||||
if err == nil { |
||||
field.SetUint(uintVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setBoolField(val string, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "false" |
||||
} |
||||
boolVal, err := strconv.ParseBool(val) |
||||
if err == nil { |
||||
field.SetBool(boolVal) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func setFloatField(val string, bitSize int, field reflect.Value) error { |
||||
if val == "" { |
||||
val = "0.0" |
||||
} |
||||
floatVal, err := strconv.ParseFloat(val, bitSize) |
||||
if err == nil { |
||||
field.SetFloat(floatVal) |
||||
} |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { |
||||
timeFormat := structField.Tag.Get("time_format") |
||||
if timeFormat == "" { |
||||
return errors.New("Blank time format") |
||||
} |
||||
|
||||
if val == "" { |
||||
value.Set(reflect.ValueOf(time.Time{})) |
||||
return nil |
||||
} |
||||
|
||||
l := time.Local |
||||
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { |
||||
l = time.UTC |
||||
} |
||||
|
||||
if locTag := structField.Tag.Get("time_location"); locTag != "" { |
||||
loc, err := time.LoadLocation(locTag) |
||||
if err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
l = loc |
||||
} |
||||
|
||||
t, err := time.ParseInLocation(timeFormat, val, l) |
||||
if err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
|
||||
value.Set(reflect.ValueOf(t)) |
||||
return nil |
||||
} |
||||
|
||||
func filterEmpty(val []string) []string { |
||||
filtered := make([]string, 0, len(val)) |
||||
for _, v := range val { |
||||
if v != "" { |
||||
filtered = append(filtered, v) |
||||
} |
||||
} |
||||
return filtered |
||||
} |
@ -0,0 +1,22 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
type jsonBinding struct{} |
||||
|
||||
func (jsonBinding) Name() string { |
||||
return "json" |
||||
} |
||||
|
||||
func (jsonBinding) Bind(req *http.Request, obj interface{}) error { |
||||
decoder := json.NewDecoder(req.Body) |
||||
if err := decoder.Decode(obj); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,19 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"net/http" |
||||
) |
||||
|
||||
type queryBinding struct{} |
||||
|
||||
func (queryBinding) Name() string { |
||||
return "query" |
||||
} |
||||
|
||||
func (queryBinding) Bind(req *http.Request, obj interface{}) error { |
||||
values := req.URL.Query() |
||||
if err := mapForm(obj, values); err != nil { |
||||
return err |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,44 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binding |
||||
|
||||
import ( |
||||
"strings" |
||||
) |
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string |
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) { |
||||
if idx := strings.Index(tag, ","); idx != -1 { |
||||
return tag[:idx], tagOptions(tag[idx+1:]) |
||||
} |
||||
return tag, tagOptions("") |
||||
} |
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool { |
||||
if len(o) == 0 { |
||||
return false |
||||
} |
||||
s := string(o) |
||||
for s != "" { |
||||
var next string |
||||
i := strings.Index(s, ",") |
||||
if i >= 0 { |
||||
s, next = s[:i], s[i+1:] |
||||
} |
||||
if s == optionName { |
||||
return true |
||||
} |
||||
s = next |
||||
} |
||||
return false |
||||
} |
@ -0,0 +1,209 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"bytes" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
type testInterface interface { |
||||
String() string |
||||
} |
||||
|
||||
type substructNoValidation struct { |
||||
IString string |
||||
IInt int |
||||
} |
||||
|
||||
type mapNoValidationSub map[string]substructNoValidation |
||||
|
||||
type structNoValidationValues struct { |
||||
substructNoValidation |
||||
|
||||
Boolean bool |
||||
|
||||
Uinteger uint |
||||
Integer int |
||||
Integer8 int8 |
||||
Integer16 int16 |
||||
Integer32 int32 |
||||
Integer64 int64 |
||||
Uinteger8 uint8 |
||||
Uinteger16 uint16 |
||||
Uinteger32 uint32 |
||||
Uinteger64 uint64 |
||||
|
||||
Float32 float32 |
||||
Float64 float64 |
||||
|
||||
String string |
||||
|
||||
Date time.Time |
||||
|
||||
Struct substructNoValidation |
||||
InlinedStruct struct { |
||||
String []string |
||||
Integer int |
||||
} |
||||
|
||||
IntSlice []int |
||||
IntPointerSlice []*int |
||||
StructPointerSlice []*substructNoValidation |
||||
StructSlice []substructNoValidation |
||||
InterfaceSlice []testInterface |
||||
|
||||
UniversalInterface interface{} |
||||
CustomInterface testInterface |
||||
|
||||
FloatMap map[string]float32 |
||||
StructMap mapNoValidationSub |
||||
} |
||||
|
||||
func createNoValidationValues() structNoValidationValues { |
||||
integer := 1 |
||||
s := structNoValidationValues{ |
||||
Boolean: true, |
||||
Uinteger: 1 << 29, |
||||
Integer: -10000, |
||||
Integer8: 120, |
||||
Integer16: -20000, |
||||
Integer32: 1 << 29, |
||||
Integer64: 1 << 61, |
||||
Uinteger8: 250, |
||||
Uinteger16: 50000, |
||||
Uinteger32: 1 << 31, |
||||
Uinteger64: 1 << 62, |
||||
Float32: 123.456, |
||||
Float64: 123.456789, |
||||
String: "text", |
||||
Date: time.Time{}, |
||||
CustomInterface: &bytes.Buffer{}, |
||||
Struct: substructNoValidation{}, |
||||
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3}, |
||||
IntPointerSlice: []*int{&integer}, |
||||
StructSlice: []substructNoValidation{}, |
||||
UniversalInterface: 1.2, |
||||
FloatMap: map[string]float32{ |
||||
"foo": 1.23, |
||||
"bar": 232.323, |
||||
}, |
||||
StructMap: mapNoValidationSub{ |
||||
"foo": substructNoValidation{}, |
||||
"bar": substructNoValidation{}, |
||||
}, |
||||
// StructPointerSlice []noValidationSub
|
||||
// InterfaceSlice []testInterface
|
||||
} |
||||
s.InlinedStruct.Integer = 1000 |
||||
s.InlinedStruct.String = []string{"first", "second"} |
||||
s.IString = "substring" |
||||
s.IInt = 987654 |
||||
return s |
||||
} |
||||
|
||||
func TestValidateNoValidationValues(t *testing.T) { |
||||
origin := createNoValidationValues() |
||||
test := createNoValidationValues() |
||||
empty := structNoValidationValues{} |
||||
|
||||
assert.Nil(t, validate(test)) |
||||
assert.Nil(t, validate(&test)) |
||||
assert.Nil(t, validate(empty)) |
||||
assert.Nil(t, validate(&empty)) |
||||
|
||||
assert.Equal(t, origin, test) |
||||
} |
||||
|
||||
type structNoValidationPointer struct { |
||||
// substructNoValidation
|
||||
|
||||
Boolean bool |
||||
|
||||
Uinteger *uint |
||||
Integer *int |
||||
Integer8 *int8 |
||||
Integer16 *int16 |
||||
Integer32 *int32 |
||||
Integer64 *int64 |
||||
Uinteger8 *uint8 |
||||
Uinteger16 *uint16 |
||||
Uinteger32 *uint32 |
||||
Uinteger64 *uint64 |
||||
|
||||
Float32 *float32 |
||||
Float64 *float64 |
||||
|
||||
String *string |
||||
|
||||
Date *time.Time |
||||
|
||||
Struct *substructNoValidation |
||||
|
||||
IntSlice *[]int |
||||
IntPointerSlice *[]*int |
||||
StructPointerSlice *[]*substructNoValidation |
||||
StructSlice *[]substructNoValidation |
||||
InterfaceSlice *[]testInterface |
||||
|
||||
FloatMap *map[string]float32 |
||||
StructMap *mapNoValidationSub |
||||
} |
||||
|
||||
func TestValidateNoValidationPointers(t *testing.T) { |
||||
//origin := createNoValidation_values()
|
||||
//test := createNoValidation_values()
|
||||
empty := structNoValidationPointer{} |
||||
|
||||
//assert.Nil(t, validate(test))
|
||||
//assert.Nil(t, validate(&test))
|
||||
assert.Nil(t, validate(empty)) |
||||
assert.Nil(t, validate(&empty)) |
||||
|
||||
//assert.Equal(t, origin, test)
|
||||
} |
||||
|
||||
type Object map[string]interface{} |
||||
|
||||
func TestValidatePrimitives(t *testing.T) { |
||||
obj := Object{"foo": "bar", "bar": 1} |
||||
assert.NoError(t, validate(obj)) |
||||
assert.NoError(t, validate(&obj)) |
||||
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1}) |
||||
|
||||
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} |
||||
assert.NoError(t, validate(obj2)) |
||||
assert.NoError(t, validate(&obj2)) |
||||
|
||||
nu := 10 |
||||
assert.NoError(t, validate(nu)) |
||||
assert.NoError(t, validate(&nu)) |
||||
assert.Equal(t, nu, 10) |
||||
|
||||
str := "value" |
||||
assert.NoError(t, validate(str)) |
||||
assert.NoError(t, validate(&str)) |
||||
assert.Equal(t, str, "value") |
||||
} |
||||
|
||||
// structCustomValidation is a helper struct we use to check that
|
||||
// custom validation can be registered on it.
|
||||
// The `notone` binding directive is for custom validation and registered later.
|
||||
// type structCustomValidation struct {
|
||||
// Integer int `binding:"notone"`
|
||||
// }
|
||||
|
||||
// notOne is a custom validator meant to be used with `validator.v8` library.
|
||||
// The method signature for `v9` is significantly different and this function
|
||||
// would need to be changed for tests to pass after upgrade.
|
||||
// See https://github.com/gin-gonic/gin/pull/1015.
|
||||
// func notOne(
|
||||
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
|
||||
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
|
||||
// ) bool {
|
||||
// if val, ok := field.Interface().(int); ok {
|
||||
// return val != 1
|
||||
// }
|
||||
// return false
|
||||
// }
|
@ -0,0 +1,22 @@ |
||||
package binding |
||||
|
||||
import ( |
||||
"encoding/xml" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
type xmlBinding struct{} |
||||
|
||||
func (xmlBinding) Name() string { |
||||
return "xml" |
||||
} |
||||
|
||||
func (xmlBinding) Bind(req *http.Request, obj interface{}) error { |
||||
decoder := xml.NewDecoder(req.Body) |
||||
if err := decoder.Decode(obj); err != nil { |
||||
return errors.WithStack(err) |
||||
} |
||||
return validate(obj) |
||||
} |
@ -0,0 +1,306 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"context" |
||||
"math" |
||||
"net/http" |
||||
"strconv" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/ecode" |
||||
"github.com/bilibili/Kratos/pkg/net/http/blademaster/binding" |
||||
"github.com/bilibili/Kratos/pkg/net/http/blademaster/render" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
"github.com/gogo/protobuf/types" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
_abortIndex int8 = math.MaxInt8 / 2 |
||||
) |
||||
|
||||
var ( |
||||
_openParen = []byte("(") |
||||
_closeParen = []byte(")") |
||||
) |
||||
|
||||
// Context is the most important part. It allows us to pass variables between
|
||||
// middleware, manage the flow, validate the JSON of a request and render a
|
||||
// JSON response for example.
|
||||
type Context struct { |
||||
context.Context |
||||
|
||||
Request *http.Request |
||||
Writer http.ResponseWriter |
||||
|
||||
// flow control
|
||||
index int8 |
||||
handlers []HandlerFunc |
||||
|
||||
// Keys is a key/value pair exclusively for the context of each request.
|
||||
Keys map[string]interface{} |
||||
|
||||
Error error |
||||
|
||||
method string |
||||
engine *Engine |
||||
} |
||||
|
||||
/************************************/ |
||||
/*********** FLOW CONTROL ***********/ |
||||
/************************************/ |
||||
|
||||
// Next should be used only inside middleware.
|
||||
// It executes the pending handlers in the chain inside the calling handler.
|
||||
// See example in godoc.
|
||||
func (c *Context) Next() { |
||||
c.index++ |
||||
s := int8(len(c.handlers)) |
||||
for ; c.index < s; c.index++ { |
||||
// only check method on last handler, otherwise middlewares
|
||||
// will never be effected if request method is not matched
|
||||
if c.index == s-1 && c.method != c.Request.Method { |
||||
code := http.StatusMethodNotAllowed |
||||
c.Error = ecode.MethodNotAllowed |
||||
http.Error(c.Writer, http.StatusText(code), code) |
||||
return |
||||
} |
||||
|
||||
c.handlers[c.index](c) |
||||
} |
||||
} |
||||
|
||||
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
|
||||
// Let's say you have an authorization middleware that validates that the current request is authorized.
|
||||
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
|
||||
// for this request are not called.
|
||||
func (c *Context) Abort() { |
||||
c.index = _abortIndex |
||||
} |
||||
|
||||
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
|
||||
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
|
||||
func (c *Context) AbortWithStatus(code int) { |
||||
c.Status(code) |
||||
c.Abort() |
||||
} |
||||
|
||||
// IsAborted returns true if the current context was aborted.
|
||||
func (c *Context) IsAborted() bool { |
||||
return c.index >= _abortIndex |
||||
} |
||||
|
||||
/************************************/ |
||||
/******** METADATA MANAGEMENT********/ |
||||
/************************************/ |
||||
|
||||
// Set is used to store a new key/value pair exclusively for this context.
|
||||
// It also lazy initializes c.Keys if it was not used previously.
|
||||
func (c *Context) Set(key string, value interface{}) { |
||||
if c.Keys == nil { |
||||
c.Keys = make(map[string]interface{}) |
||||
} |
||||
c.Keys[key] = value |
||||
} |
||||
|
||||
// Get returns the value for the given key, ie: (value, true).
|
||||
// If the value does not exists it returns (nil, false)
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) { |
||||
value, exists = c.Keys[key] |
||||
return |
||||
} |
||||
|
||||
/************************************/ |
||||
/******** RESPONSE RENDERING ********/ |
||||
/************************************/ |
||||
|
||||
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
|
||||
func bodyAllowedForStatus(status int) bool { |
||||
switch { |
||||
case status >= 100 && status <= 199: |
||||
return false |
||||
case status == 204: |
||||
return false |
||||
case status == 304: |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// Status sets the HTTP response code.
|
||||
func (c *Context) Status(code int) { |
||||
c.Writer.WriteHeader(code) |
||||
} |
||||
|
||||
// Render http response with http code by a render instance.
|
||||
func (c *Context) Render(code int, r render.Render) { |
||||
r.WriteContentType(c.Writer) |
||||
if code > 0 { |
||||
c.Status(code) |
||||
} |
||||
|
||||
if !bodyAllowedForStatus(code) { |
||||
return |
||||
} |
||||
|
||||
params := c.Request.Form |
||||
|
||||
cb := params.Get("callback") |
||||
jsonp := cb != "" && params.Get("jsonp") == "jsonp" |
||||
if jsonp { |
||||
c.Writer.Write([]byte(cb)) |
||||
c.Writer.Write(_openParen) |
||||
} |
||||
|
||||
if err := r.Render(c.Writer); err != nil { |
||||
c.Error = err |
||||
return |
||||
} |
||||
|
||||
if jsonp { |
||||
if _, err := c.Writer.Write(_closeParen); err != nil { |
||||
c.Error = errors.WithStack(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// JSON serializes the given struct as JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSON(data interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.JSON{ |
||||
Code: bcode.Code(), |
||||
Message: bcode.Message(), |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// JSONMap serializes the given map as map JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSONMap(data map[string]interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
data["code"] = bcode.Code() |
||||
if _, ok := data["message"]; !ok { |
||||
data["message"] = bcode.Message() |
||||
} |
||||
c.Render(code, render.MapJSON(data)) |
||||
} |
||||
|
||||
// XML serializes the given struct as XML into the response body.
|
||||
// It also sets the Content-Type as "application/xml".
|
||||
func (c *Context) XML(data interface{}, err error) { |
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
// TODO app allow 5xx?
|
||||
/* |
||||
if bcode.Code() == -500 { |
||||
code = http.StatusServiceUnavailable |
||||
} |
||||
*/ |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.XML{ |
||||
Code: bcode.Code(), |
||||
Message: bcode.Message(), |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// Protobuf serializes the given struct as PB into the response body.
|
||||
// It also sets the ContentType as "application/x-protobuf".
|
||||
func (c *Context) Protobuf(data proto.Message, err error) { |
||||
var ( |
||||
bytes []byte |
||||
) |
||||
|
||||
code := http.StatusOK |
||||
c.Error = err |
||||
bcode := ecode.Cause(err) |
||||
|
||||
any := new(types.Any) |
||||
if data != nil { |
||||
if bytes, err = proto.Marshal(data); err != nil { |
||||
c.Error = errors.WithStack(err) |
||||
return |
||||
} |
||||
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data) |
||||
any.Value = bytes |
||||
} |
||||
writeStatusCode(c.Writer, bcode.Code()) |
||||
c.Render(code, render.PB{ |
||||
Code: int64(bcode.Code()), |
||||
Message: bcode.Message(), |
||||
Data: any, |
||||
}) |
||||
} |
||||
|
||||
// Bytes writes some data into the body stream and updates the HTTP code.
|
||||
func (c *Context) Bytes(code int, contentType string, data ...[]byte) { |
||||
c.Render(code, render.Data{ |
||||
ContentType: contentType, |
||||
Data: data, |
||||
}) |
||||
} |
||||
|
||||
// String writes the given string into the response body.
|
||||
func (c *Context) String(code int, format string, values ...interface{}) { |
||||
c.Render(code, render.String{Format: format, Data: values}) |
||||
} |
||||
|
||||
// Redirect returns a HTTP redirect to the specific location.
|
||||
func (c *Context) Redirect(code int, location string) { |
||||
c.Render(-1, render.Redirect{ |
||||
Code: code, |
||||
Location: location, |
||||
Request: c.Request, |
||||
}) |
||||
} |
||||
|
||||
// BindWith bind req arg with parser.
|
||||
func (c *Context) BindWith(obj interface{}, b binding.Binding) error { |
||||
return c.mustBindWith(obj, b) |
||||
} |
||||
|
||||
// Bind bind req arg with defult form binding.
|
||||
func (c *Context) Bind(obj interface{}) error { |
||||
return c.mustBindWith(obj, binding.Form) |
||||
} |
||||
|
||||
// mustBindWith binds the passed struct pointer using the specified binding engine.
|
||||
// It will abort the request with HTTP 400 if any error ocurrs.
|
||||
// See the binding package.
|
||||
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) { |
||||
if err = b.Bind(c.Request, obj); err != nil { |
||||
c.Error = ecode.RequestErr |
||||
c.Render(http.StatusOK, render.JSON{ |
||||
Code: ecode.RequestErr.Code(), |
||||
Message: err.Error(), |
||||
Data: nil, |
||||
}) |
||||
c.Abort() |
||||
} |
||||
return |
||||
} |
||||
|
||||
func writeStatusCode(w http.ResponseWriter, ecode int) { |
||||
header := w.Header() |
||||
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10)) |
||||
} |
@ -0,0 +1,249 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"net/http" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// CORSConfig represents all available options for the middleware.
|
||||
type CORSConfig struct { |
||||
AllowAllOrigins bool |
||||
|
||||
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// Default value is []
|
||||
AllowOrigins []string |
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It take the origin
|
||||
// as argument and returns true if allowed or false otherwise. If this option is
|
||||
// set, the content of AllowedOrigins is ignored.
|
||||
AllowOriginFunc func(origin string) bool |
||||
|
||||
// AllowedMethods is a list of methods the client is allowed to use with
|
||||
// cross-domain requests. Default value is simple methods (GET and POST)
|
||||
AllowMethods []string |
||||
|
||||
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||
// cross-domain requests.
|
||||
AllowHeaders []string |
||||
|
||||
// AllowCredentials indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
AllowCredentials bool |
||||
|
||||
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||
// API specification
|
||||
ExposeHeaders []string |
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached
|
||||
MaxAge time.Duration |
||||
} |
||||
|
||||
type cors struct { |
||||
allowAllOrigins bool |
||||
allowCredentials bool |
||||
allowOriginFunc func(string) bool |
||||
allowOrigins []string |
||||
normalHeaders http.Header |
||||
preflightHeaders http.Header |
||||
} |
||||
|
||||
type converter func(string) string |
||||
|
||||
// Validate is check configuration of user defined.
|
||||
func (c *CORSConfig) Validate() error { |
||||
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { |
||||
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") |
||||
} |
||||
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { |
||||
return errors.New("conflict settings: all origins disabled") |
||||
} |
||||
for _, origin := range c.AllowOrigins { |
||||
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { |
||||
return errors.New("bad origin: origins must either be '*' or include http:// or https://") |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// CORS returns the location middleware with default configuration.
|
||||
func CORS(allowOriginHosts []string) HandlerFunc { |
||||
config := &CORSConfig{ |
||||
AllowMethods: []string{"GET", "POST"}, |
||||
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"}, |
||||
AllowCredentials: true, |
||||
MaxAge: time.Duration(0), |
||||
AllowOriginFunc: func(origin string) bool { |
||||
for _, host := range allowOriginHosts { |
||||
if strings.HasSuffix(strings.ToLower(origin), host) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
}, |
||||
} |
||||
return newCORS(config) |
||||
} |
||||
|
||||
// newCORS returns the location middleware with user-defined custom configuration.
|
||||
func newCORS(config *CORSConfig) HandlerFunc { |
||||
if err := config.Validate(); err != nil { |
||||
panic(err.Error()) |
||||
} |
||||
cors := &cors{ |
||||
allowOriginFunc: config.AllowOriginFunc, |
||||
allowAllOrigins: config.AllowAllOrigins, |
||||
allowCredentials: config.AllowCredentials, |
||||
allowOrigins: normalize(config.AllowOrigins), |
||||
normalHeaders: generateNormalHeaders(config), |
||||
preflightHeaders: generatePreflightHeaders(config), |
||||
} |
||||
|
||||
return func(c *Context) { |
||||
cors.applyCORS(c) |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) applyCORS(c *Context) { |
||||
origin := c.Request.Header.Get("Origin") |
||||
if len(origin) == 0 { |
||||
// request is not a CORS request
|
||||
return |
||||
} |
||||
if !cors.validateOrigin(origin) { |
||||
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin) |
||||
c.AbortWithStatus(http.StatusForbidden) |
||||
return |
||||
} |
||||
|
||||
if c.Request.Method == "OPTIONS" { |
||||
cors.handlePreflight(c) |
||||
defer c.AbortWithStatus(200) |
||||
} else { |
||||
cors.handleNormal(c) |
||||
} |
||||
|
||||
if !cors.allowAllOrigins { |
||||
header := c.Writer.Header() |
||||
header.Set("Access-Control-Allow-Origin", origin) |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) validateOrigin(origin string) bool { |
||||
if cors.allowAllOrigins { |
||||
return true |
||||
} |
||||
for _, value := range cors.allowOrigins { |
||||
if value == origin { |
||||
return true |
||||
} |
||||
} |
||||
if cors.allowOriginFunc != nil { |
||||
return cors.allowOriginFunc(origin) |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func (cors *cors) handlePreflight(c *Context) { |
||||
header := c.Writer.Header() |
||||
for key, value := range cors.preflightHeaders { |
||||
header[key] = value |
||||
} |
||||
} |
||||
|
||||
func (cors *cors) handleNormal(c *Context) { |
||||
header := c.Writer.Header() |
||||
for key, value := range cors.normalHeaders { |
||||
header[key] = value |
||||
} |
||||
} |
||||
|
||||
func generateNormalHeaders(c *CORSConfig) http.Header { |
||||
headers := make(http.Header) |
||||
if c.AllowCredentials { |
||||
headers.Set("Access-Control-Allow-Credentials", "true") |
||||
} |
||||
|
||||
// backport support for early browsers
|
||||
if len(c.AllowMethods) > 0 { |
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||
value := strings.Join(allowMethods, ",") |
||||
headers.Set("Access-Control-Allow-Methods", value) |
||||
} |
||||
|
||||
if len(c.ExposeHeaders) > 0 { |
||||
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) |
||||
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) |
||||
} |
||||
if c.AllowAllOrigins { |
||||
headers.Set("Access-Control-Allow-Origin", "*") |
||||
} else { |
||||
headers.Set("Vary", "Origin") |
||||
} |
||||
return headers |
||||
} |
||||
|
||||
func generatePreflightHeaders(c *CORSConfig) http.Header { |
||||
headers := make(http.Header) |
||||
if c.AllowCredentials { |
||||
headers.Set("Access-Control-Allow-Credentials", "true") |
||||
} |
||||
if len(c.AllowMethods) > 0 { |
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||
value := strings.Join(allowMethods, ",") |
||||
headers.Set("Access-Control-Allow-Methods", value) |
||||
} |
||||
if len(c.AllowHeaders) > 0 { |
||||
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) |
||||
value := strings.Join(allowHeaders, ",") |
||||
headers.Set("Access-Control-Allow-Headers", value) |
||||
} |
||||
if c.MaxAge > time.Duration(0) { |
||||
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) |
||||
headers.Set("Access-Control-Max-Age", value) |
||||
} |
||||
if c.AllowAllOrigins { |
||||
headers.Set("Access-Control-Allow-Origin", "*") |
||||
} else { |
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
|
||||
headers.Add("Vary", "Origin") |
||||
headers.Add("Vary", "Access-Control-Request-Method") |
||||
headers.Add("Vary", "Access-Control-Request-Headers") |
||||
} |
||||
return headers |
||||
} |
||||
|
||||
func normalize(values []string) []string { |
||||
if values == nil { |
||||
return nil |
||||
} |
||||
distinctMap := make(map[string]bool, len(values)) |
||||
normalized := make([]string, 0, len(values)) |
||||
for _, value := range values { |
||||
value = strings.TrimSpace(value) |
||||
value = strings.ToLower(value) |
||||
if _, seen := distinctMap[value]; !seen { |
||||
normalized = append(normalized, value) |
||||
distinctMap[value] = true |
||||
} |
||||
} |
||||
return normalized |
||||
} |
||||
|
||||
func convert(s []string, c converter) []string { |
||||
var out []string |
||||
for _, i := range s { |
||||
out = append(out, c(i)) |
||||
} |
||||
return out |
||||
} |
@ -0,0 +1,69 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"net/url" |
||||
"regexp" |
||||
"strings" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
) |
||||
|
||||
func matchHostSuffix(suffix string) func(*url.URL) bool { |
||||
return func(uri *url.URL) bool { |
||||
return strings.HasSuffix(strings.ToLower(uri.Host), suffix) |
||||
} |
||||
} |
||||
|
||||
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool { |
||||
return func(uri *url.URL) bool { |
||||
return pattern.MatchString(strings.ToLower(uri.String())) |
||||
} |
||||
} |
||||
|
||||
// CSRF returns the csrf middleware to prevent invalid cross site request.
|
||||
// Only referer is checked currently.
|
||||
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc { |
||||
validations := []func(*url.URL) bool{} |
||||
|
||||
addHostSuffix := func(suffix string) { |
||||
validations = append(validations, matchHostSuffix(suffix)) |
||||
} |
||||
addPattern := func(pattern string) { |
||||
validations = append(validations, matchPattern(regexp.MustCompile(pattern))) |
||||
} |
||||
|
||||
for _, r := range allowHosts { |
||||
addHostSuffix(r) |
||||
} |
||||
for _, p := range allowPattern { |
||||
addPattern(p) |
||||
} |
||||
|
||||
return func(c *Context) { |
||||
referer := c.Request.Header.Get("Referer") |
||||
params := c.Request.Form |
||||
cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "") |
||||
if referer == "" { |
||||
if !cross { |
||||
return |
||||
} |
||||
log.V(5).Info("The request's Referer header is empty.") |
||||
c.AbortWithStatus(403) |
||||
return |
||||
} |
||||
illegal := true |
||||
if uri, err := url.Parse(referer); err == nil && uri.Host != "" { |
||||
for _, validate := range validations { |
||||
if validate(uri) { |
||||
illegal = false |
||||
break |
||||
} |
||||
} |
||||
} |
||||
if illegal { |
||||
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer) |
||||
c.AbortWithStatus(403) |
||||
return |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,69 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strconv" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/ecode" |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/net/metadata" |
||||
) |
||||
|
||||
// Logger is logger middleware
|
||||
func Logger() HandlerFunc { |
||||
const noUser = "no_user" |
||||
return func(c *Context) { |
||||
now := time.Now() |
||||
ip := metadata.String(c, metadata.RemoteIP) |
||||
req := c.Request |
||||
path := req.URL.Path |
||||
params := req.Form |
||||
var quota float64 |
||||
if deadline, ok := c.Context.Deadline(); ok { |
||||
quota = time.Until(deadline).Seconds() |
||||
} |
||||
|
||||
c.Next() |
||||
|
||||
err := c.Error |
||||
cerr := ecode.Cause(err) |
||||
dt := time.Since(now) |
||||
caller := metadata.String(c, metadata.Caller) |
||||
if caller == "" { |
||||
caller = noUser |
||||
} |
||||
|
||||
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10)) |
||||
stats.Timing(caller, int64(dt/time.Millisecond), path[1:]) |
||||
|
||||
lf := log.Infov |
||||
errmsg := "" |
||||
isSlow := dt >= (time.Millisecond * 500) |
||||
if err != nil { |
||||
errmsg = err.Error() |
||||
lf = log.Errorv |
||||
if cerr.Code() > 0 { |
||||
lf = log.Warnv |
||||
} |
||||
} else { |
||||
if isSlow { |
||||
lf = log.Warnv |
||||
} |
||||
} |
||||
lf(c, |
||||
log.KVString("method", req.Method), |
||||
log.KVString("ip", ip), |
||||
log.KVString("user", caller), |
||||
log.KVString("path", path), |
||||
log.KVString("params", params.Encode()), |
||||
log.KVInt("ret", cerr.Code()), |
||||
log.KVString("msg", cerr.Message()), |
||||
log.KVString("stack", fmt.Sprintf("%+v", err)), |
||||
log.KVString("err", errmsg), |
||||
log.KVFloat64("timeout_quota", quota), |
||||
log.KVFloat64("ts", dt.Seconds()), |
||||
log.KVString("source", "http-access-log"), |
||||
) |
||||
} |
||||
} |
@ -0,0 +1,46 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"flag" |
||||
"net/http" |
||||
"net/http/pprof" |
||||
"os" |
||||
"sync" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/conf/dsn" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var ( |
||||
_perfOnce sync.Once |
||||
_perfDSN string |
||||
) |
||||
|
||||
func init() { |
||||
v := os.Getenv("HTTP_PERF") |
||||
if v == "" { |
||||
v = "tcp://0.0.0.0:2333" |
||||
} |
||||
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.") |
||||
} |
||||
|
||||
func startPerf() { |
||||
_perfOnce.Do(func() { |
||||
mux := http.NewServeMux() |
||||
mux.HandleFunc("/debug/pprof/", pprof.Index) |
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) |
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile) |
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) |
||||
|
||||
go func() { |
||||
d, err := dsn.Parse(_perfDSN) |
||||
if err != nil { |
||||
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err)) |
||||
} |
||||
if err := http.ListenAndServe(d.Host, mux); err != nil { |
||||
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err)) |
||||
} |
||||
}() |
||||
}) |
||||
} |
@ -0,0 +1,12 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||
) |
||||
|
||||
func monitor() HandlerFunc { |
||||
return func(c *Context) { |
||||
h := promhttp.Handler() |
||||
h.ServeHTTP(c.Writer, c.Request) |
||||
} |
||||
} |
@ -0,0 +1,32 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http/httputil" |
||||
"os" |
||||
"runtime" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
) |
||||
|
||||
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||
func Recovery() HandlerFunc { |
||||
return func(c *Context) { |
||||
defer func() { |
||||
var rawReq []byte |
||||
if err := recover(); err != nil { |
||||
const size = 64 << 10 |
||||
buf := make([]byte, size) |
||||
buf = buf[:runtime.Stack(buf, false)] |
||||
if c.Request != nil { |
||||
rawReq, _ = httputil.DumpRequest(c.Request, false) |
||||
} |
||||
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf) |
||||
fmt.Fprintf(os.Stderr, pl) |
||||
log.Error(pl) |
||||
c.AbortWithStatus(500) |
||||
} |
||||
}() |
||||
c.Next() |
||||
} |
||||
} |
@ -0,0 +1,30 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// Data common bytes struct.
|
||||
type Data struct { |
||||
ContentType string |
||||
Data [][]byte |
||||
} |
||||
|
||||
// Render (Data) writes data with custom ContentType.
|
||||
func (r Data) Render(w http.ResponseWriter) (err error) { |
||||
r.WriteContentType(w) |
||||
for _, d := range r.Data { |
||||
if _, err = w.Write(d); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// WriteContentType writes data with custom ContentType.
|
||||
func (r Data) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, []string{r.ContentType}) |
||||
} |
@ -0,0 +1,58 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var jsonContentType = []string{"application/json; charset=utf-8"} |
||||
|
||||
// JSON common json struct.
|
||||
type JSON struct { |
||||
Code int `json:"code"` |
||||
Message string `json:"message"` |
||||
TTL int `json:"ttl"` |
||||
Data interface{} `json:"data,omitempty"` |
||||
} |
||||
|
||||
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) { |
||||
var jsonBytes []byte |
||||
writeContentType(w, jsonContentType) |
||||
if jsonBytes, err = json.Marshal(obj); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
if _, err = w.Write(jsonBytes); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// Render (JSON) writes data with json ContentType.
|
||||
func (r JSON) Render(w http.ResponseWriter) error { |
||||
// FIXME(zhoujiahui): the TTL field will be configurable in the future
|
||||
if r.TTL <= 0 { |
||||
r.TTL = 1 |
||||
} |
||||
return writeJSON(w, r) |
||||
} |
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (r JSON) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, jsonContentType) |
||||
} |
||||
|
||||
// MapJSON common map json struct.
|
||||
type MapJSON map[string]interface{} |
||||
|
||||
// Render (MapJSON) writes data with json ContentType.
|
||||
func (m MapJSON) Render(w http.ResponseWriter) error { |
||||
return writeJSON(w, m) |
||||
} |
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (m MapJSON) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, jsonContentType) |
||||
} |
@ -0,0 +1,38 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/gogo/protobuf/proto" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var pbContentType = []string{"application/x-protobuf"} |
||||
|
||||
// Render (PB) writes data with protobuf ContentType.
|
||||
func (r PB) Render(w http.ResponseWriter) error { |
||||
if r.TTL <= 0 { |
||||
r.TTL = 1 |
||||
} |
||||
return writePB(w, r) |
||||
} |
||||
|
||||
// WriteContentType write protobuf ContentType.
|
||||
func (r PB) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, pbContentType) |
||||
} |
||||
|
||||
func writePB(w http.ResponseWriter, obj PB) (err error) { |
||||
var pbBytes []byte |
||||
writeContentType(w, pbContentType) |
||||
|
||||
if pbBytes, err = proto.Marshal(&obj); err != nil { |
||||
err = errors.WithStack(err) |
||||
return |
||||
} |
||||
|
||||
if _, err = w.Write(pbBytes); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,26 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// Redirect render for redirect to specified location.
|
||||
type Redirect struct { |
||||
Code int |
||||
Request *http.Request |
||||
Location string |
||||
} |
||||
|
||||
// Render (Redirect) redirect to specified location.
|
||||
func (r Redirect) Render(w http.ResponseWriter) error { |
||||
if (r.Code < 300 || r.Code > 308) && r.Code != 201 { |
||||
return errors.Errorf("Cannot redirect with status code %d", r.Code) |
||||
} |
||||
http.Redirect(w, r.Request, r.Location, r.Code) |
||||
return nil |
||||
} |
||||
|
||||
// WriteContentType noneContentType.
|
||||
func (r Redirect) WriteContentType(http.ResponseWriter) {} |
@ -0,0 +1,30 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"net/http" |
||||
) |
||||
|
||||
// Render http reponse render.
|
||||
type Render interface { |
||||
// Render render it to http response writer.
|
||||
Render(http.ResponseWriter) error |
||||
// WriteContentType write content-type to http response writer.
|
||||
WriteContentType(w http.ResponseWriter) |
||||
} |
||||
|
||||
var ( |
||||
_ Render = JSON{} |
||||
_ Render = MapJSON{} |
||||
_ Render = XML{} |
||||
_ Render = String{} |
||||
_ Render = Redirect{} |
||||
_ Render = Data{} |
||||
_ Render = PB{} |
||||
) |
||||
|
||||
func writeContentType(w http.ResponseWriter, value []string) { |
||||
header := w.Header() |
||||
if val := header["Content-Type"]; len(val) == 0 { |
||||
header["Content-Type"] = value |
||||
} |
||||
} |
@ -0,0 +1,89 @@ |
||||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: pb.proto
|
||||
|
||||
/* |
||||
Package render is a generated protocol buffer package. |
||||
|
||||
It is generated from these files: |
||||
pb.proto |
||||
|
||||
It has these top-level messages: |
||||
PB |
||||
*/ |
||||
package render |
||||
|
||||
import proto "github.com/gogo/protobuf/proto" |
||||
import fmt "fmt" |
||||
import math "math" |
||||
import google_protobuf "github.com/gogo/protobuf/types" |
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal |
||||
var _ = fmt.Errorf |
||||
var _ = math.Inf |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type PB struct { |
||||
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` |
||||
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` |
||||
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"` |
||||
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"` |
||||
} |
||||
|
||||
func (m *PB) Reset() { *m = PB{} } |
||||
func (m *PB) String() string { return proto.CompactTextString(m) } |
||||
func (*PB) ProtoMessage() {} |
||||
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} } |
||||
|
||||
func (m *PB) GetCode() int64 { |
||||
if m != nil { |
||||
return m.Code |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *PB) GetMessage() string { |
||||
if m != nil { |
||||
return m.Message |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func (m *PB) GetTTL() uint64 { |
||||
if m != nil { |
||||
return m.TTL |
||||
} |
||||
return 0 |
||||
} |
||||
|
||||
func (m *PB) GetData() *google_protobuf.Any { |
||||
if m != nil { |
||||
return m.Data |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func init() { |
||||
proto.RegisterType((*PB)(nil), "render.PB") |
||||
} |
||||
|
||||
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) } |
||||
|
||||
var fileDescriptorPb = []byte{ |
||||
// 154 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b, |
||||
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf, |
||||
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94, |
||||
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30, |
||||
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, |
||||
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48, |
||||
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92, |
||||
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66, |
||||
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00, |
||||
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00, |
||||
} |
@ -0,0 +1,14 @@ |
||||
// use under command to generate pb.pb.go |
||||
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto |
||||
syntax = "proto3"; |
||||
package render; |
||||
|
||||
import "google/protobuf/any.proto"; |
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||
|
||||
message PB { |
||||
int64 Code = 1; |
||||
string Message = 2; |
||||
uint64 TTL = 3; |
||||
google.protobuf.Any Data = 4; |
||||
} |
@ -0,0 +1,40 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
var plainContentType = []string{"text/plain; charset=utf-8"} |
||||
|
||||
// String common string struct.
|
||||
type String struct { |
||||
Format string |
||||
Data []interface{} |
||||
} |
||||
|
||||
// Render (String) writes data with custom ContentType.
|
||||
func (r String) Render(w http.ResponseWriter) error { |
||||
return writeString(w, r.Format, r.Data) |
||||
} |
||||
|
||||
// WriteContentType writes string with text/plain ContentType.
|
||||
func (r String) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, plainContentType) |
||||
} |
||||
|
||||
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) { |
||||
writeContentType(w, plainContentType) |
||||
if len(data) > 0 { |
||||
_, err = fmt.Fprintf(w, format, data...) |
||||
} else { |
||||
_, err = io.WriteString(w, format) |
||||
} |
||||
if err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,31 @@ |
||||
package render |
||||
|
||||
import ( |
||||
"encoding/xml" |
||||
"net/http" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// XML common xml struct.
|
||||
type XML struct { |
||||
Code int |
||||
Message string |
||||
Data interface{} |
||||
} |
||||
|
||||
var xmlContentType = []string{"application/xml; charset=utf-8"} |
||||
|
||||
// Render (XML) writes data with xml ContentType.
|
||||
func (r XML) Render(w http.ResponseWriter) (err error) { |
||||
r.WriteContentType(w) |
||||
if err = xml.NewEncoder(w).Encode(r.Data); err != nil { |
||||
err = errors.WithStack(err) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// WriteContentType write xml ContentType.
|
||||
func (r XML) WriteContentType(w http.ResponseWriter) { |
||||
writeContentType(w, xmlContentType) |
||||
} |
@ -0,0 +1,166 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"regexp" |
||||
) |
||||
|
||||
// IRouter http router framework interface.
|
||||
type IRouter interface { |
||||
IRoutes |
||||
Group(string, ...HandlerFunc) *RouterGroup |
||||
} |
||||
|
||||
// IRoutes http router interface.
|
||||
type IRoutes interface { |
||||
UseFunc(...HandlerFunc) IRoutes |
||||
Use(...Handler) IRoutes |
||||
|
||||
Handle(string, string, ...HandlerFunc) IRoutes |
||||
HEAD(string, ...HandlerFunc) IRoutes |
||||
GET(string, ...HandlerFunc) IRoutes |
||||
POST(string, ...HandlerFunc) IRoutes |
||||
PUT(string, ...HandlerFunc) IRoutes |
||||
DELETE(string, ...HandlerFunc) IRoutes |
||||
} |
||||
|
||||
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
|
||||
// and an array of handlers (middleware).
|
||||
type RouterGroup struct { |
||||
Handlers []HandlerFunc |
||||
basePath string |
||||
engine *Engine |
||||
root bool |
||||
baseConfig *MethodConfig |
||||
} |
||||
|
||||
var _ IRouter = &RouterGroup{} |
||||
|
||||
// Use adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) Use(middleware ...Handler) IRoutes { |
||||
for _, m := range middleware { |
||||
group.Handlers = append(group.Handlers, m.ServeHTTP) |
||||
} |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// UseFunc adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||
group.Handlers = append(group.Handlers, middleware...) |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
|
||||
// For example, all the routes that use a common middlware for authorization could be grouped.
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { |
||||
return &RouterGroup{ |
||||
Handlers: group.combineHandlers(handlers), |
||||
basePath: group.calculateAbsolutePath(relativePath), |
||||
engine: group.engine, |
||||
root: false, |
||||
} |
||||
} |
||||
|
||||
// SetMethodConfig is used to set config on specified method
|
||||
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup { |
||||
group.baseConfig = config |
||||
return group |
||||
} |
||||
|
||||
// BasePath router group base path.
|
||||
func (group *RouterGroup) BasePath() string { |
||||
return group.basePath |
||||
} |
||||
|
||||
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
absolutePath := group.calculateAbsolutePath(relativePath) |
||||
injections := group.injections(relativePath) |
||||
handlers = group.combineHandlers(injections, handlers) |
||||
group.engine.addRoute(httpMethod, absolutePath, handlers...) |
||||
if group.baseConfig != nil { |
||||
group.engine.SetMethodConfig(absolutePath, group.baseConfig) |
||||
} |
||||
return group.returnObj() |
||||
} |
||||
|
||||
// Handle registers a new request handle and middleware with the given path and method.
|
||||
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
|
||||
// See the example code in doc.
|
||||
//
|
||||
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
|
||||
// functions can be used.
|
||||
//
|
||||
// This function is intended for bulk loading and to allow the usage of less
|
||||
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||
// communication with a proxy).
|
||||
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil { |
||||
panic("http method " + httpMethod + " is not valid") |
||||
} |
||||
return group.handle(httpMethod, relativePath, handlers...) |
||||
} |
||||
|
||||
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
|
||||
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("HEAD", relativePath, handlers...) |
||||
} |
||||
|
||||
// GET is a shortcut for router.Handle("GET", path, handle).
|
||||
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("GET", relativePath, handlers...) |
||||
} |
||||
|
||||
// POST is a shortcut for router.Handle("POST", path, handle).
|
||||
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("POST", relativePath, handlers...) |
||||
} |
||||
|
||||
// PUT is a shortcut for router.Handle("PUT", path, handle).
|
||||
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("PUT", relativePath, handlers...) |
||||
} |
||||
|
||||
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
|
||||
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||
return group.handle("DELETE", relativePath, handlers...) |
||||
} |
||||
|
||||
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc { |
||||
finalSize := len(group.Handlers) |
||||
for _, handlers := range handlerGroups { |
||||
finalSize += len(handlers) |
||||
} |
||||
if finalSize >= int(_abortIndex) { |
||||
panic("too many handlers") |
||||
} |
||||
mergedHandlers := make([]HandlerFunc, finalSize) |
||||
copy(mergedHandlers, group.Handlers) |
||||
position := len(group.Handlers) |
||||
for _, handlers := range handlerGroups { |
||||
copy(mergedHandlers[position:], handlers) |
||||
position += len(handlers) |
||||
} |
||||
return mergedHandlers |
||||
} |
||||
|
||||
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { |
||||
return joinPaths(group.basePath, relativePath) |
||||
} |
||||
|
||||
func (group *RouterGroup) returnObj() IRoutes { |
||||
if group.root { |
||||
return group.engine |
||||
} |
||||
return group |
||||
} |
||||
|
||||
// injections is
|
||||
func (group *RouterGroup) injections(relativePath string) []HandlerFunc { |
||||
absPath := group.calculateAbsolutePath(relativePath) |
||||
for _, injection := range group.engine.injections { |
||||
if !injection.pattern.MatchString(absPath) { |
||||
continue |
||||
} |
||||
return injection.handlers |
||||
} |
||||
return nil |
||||
} |
@ -0,0 +1,445 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"os" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/bilibili/Kratos/pkg/conf/dsn" |
||||
"github.com/bilibili/Kratos/pkg/log" |
||||
"github.com/bilibili/Kratos/pkg/net/ip" |
||||
"github.com/bilibili/Kratos/pkg/net/metadata" |
||||
"github.com/bilibili/Kratos/pkg/stat" |
||||
xtime "github.com/bilibili/Kratos/pkg/time" |
||||
|
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
const ( |
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
) |
||||
|
||||
var ( |
||||
_ IRouter = &Engine{} |
||||
stats = stat.HTTPServer |
||||
|
||||
_httpDSN string |
||||
) |
||||
|
||||
func init() { |
||||
addFlag(flag.CommandLine) |
||||
} |
||||
|
||||
func addFlag(fs *flag.FlagSet) { |
||||
v := os.Getenv("HTTP") |
||||
if v == "" { |
||||
v = "tcp://0.0.0.0:8000/?timeout=1s" |
||||
} |
||||
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.") |
||||
} |
||||
|
||||
func parseDSN(rawdsn string) *ServerConfig { |
||||
conf := new(ServerConfig) |
||||
d, err := dsn.Parse(rawdsn) |
||||
if err != nil { |
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||
} |
||||
if _, err = d.Bind(conf); err != nil { |
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||
} |
||||
return conf |
||||
} |
||||
|
||||
// Handler responds to an HTTP request.
|
||||
type Handler interface { |
||||
ServeHTTP(c *Context) |
||||
} |
||||
|
||||
// HandlerFunc http request handler function.
|
||||
type HandlerFunc func(*Context) |
||||
|
||||
// ServeHTTP calls f(ctx).
|
||||
func (f HandlerFunc) ServeHTTP(c *Context) { |
||||
f(c) |
||||
} |
||||
|
||||
// ServerConfig is the bm server config model
|
||||
type ServerConfig struct { |
||||
Network string `dsn:"network"` |
||||
// FIXME: rename to Address
|
||||
Addr string `dsn:"address"` |
||||
Timeout xtime.Duration `dsn:"query.timeout"` |
||||
ReadTimeout xtime.Duration `dsn:"query.readTimeout"` |
||||
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"` |
||||
} |
||||
|
||||
// MethodConfig is
|
||||
type MethodConfig struct { |
||||
Timeout xtime.Duration |
||||
} |
||||
|
||||
// Start listen and serve bm engine by given DSN.
|
||||
func (engine *Engine) Start() error { |
||||
conf := engine.conf |
||||
l, err := net.Listen(conf.Network, conf.Addr) |
||||
if err != nil { |
||||
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr) |
||||
return err |
||||
} |
||||
|
||||
log.Info("blademaster: start http listen addr: %s", conf.Addr) |
||||
server := &http.Server{ |
||||
ReadTimeout: time.Duration(conf.ReadTimeout), |
||||
WriteTimeout: time.Duration(conf.WriteTimeout), |
||||
} |
||||
go func() { |
||||
if err := engine.RunServer(server, l); err != nil { |
||||
if errors.Cause(err) == http.ErrServerClosed { |
||||
log.Info("blademaster: server closed") |
||||
return |
||||
} |
||||
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l)) |
||||
} |
||||
}() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
|
||||
// Create an instance of Engine, by using New() or Default()
|
||||
type Engine struct { |
||||
RouterGroup |
||||
|
||||
lock sync.RWMutex |
||||
conf *ServerConfig |
||||
|
||||
address string |
||||
|
||||
mux *http.ServeMux // http mux router
|
||||
server atomic.Value // store *http.Server
|
||||
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
|
||||
|
||||
pcLock sync.RWMutex |
||||
methodConfigs map[string]*MethodConfig |
||||
|
||||
injections []injection |
||||
} |
||||
|
||||
type injection struct { |
||||
pattern *regexp.Regexp |
||||
handlers []HandlerFunc |
||||
} |
||||
|
||||
// New returns a new blank Engine instance without any middleware attached.
|
||||
//
|
||||
// Deprecated: please use NewServer.
|
||||
func New() *Engine { |
||||
engine := &Engine{ |
||||
RouterGroup: RouterGroup{ |
||||
Handlers: nil, |
||||
basePath: "/", |
||||
root: true, |
||||
}, |
||||
address: ip.InternalIP(), |
||||
conf: &ServerConfig{ |
||||
Timeout: xtime.Duration(time.Second), |
||||
}, |
||||
mux: http.NewServeMux(), |
||||
metastore: make(map[string]map[string]interface{}), |
||||
methodConfigs: make(map[string]*MethodConfig), |
||||
injections: make([]injection, 0), |
||||
} |
||||
engine.RouterGroup.engine = engine |
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor()) |
||||
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||
startPerf() |
||||
return engine |
||||
} |
||||
|
||||
// NewServer returns a new blank Engine instance without any middleware attached.
|
||||
func NewServer(conf *ServerConfig) *Engine { |
||||
if conf == nil { |
||||
if !flag.Parsed() { |
||||
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n") |
||||
} |
||||
conf = parseDSN(_httpDSN) |
||||
} else { |
||||
fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n") |
||||
} |
||||
|
||||
engine := &Engine{ |
||||
RouterGroup: RouterGroup{ |
||||
Handlers: nil, |
||||
basePath: "/", |
||||
root: true, |
||||
}, |
||||
address: ip.InternalIP(), |
||||
mux: http.NewServeMux(), |
||||
metastore: make(map[string]map[string]interface{}), |
||||
methodConfigs: make(map[string]*MethodConfig), |
||||
} |
||||
if err := engine.SetConfig(conf); err != nil { |
||||
panic(err) |
||||
} |
||||
engine.RouterGroup.engine = engine |
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor()) |
||||
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||
startPerf() |
||||
return engine |
||||
} |
||||
|
||||
// SetMethodConfig is used to set config on specified path
|
||||
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) { |
||||
engine.pcLock.Lock() |
||||
engine.methodConfigs[path] = mc |
||||
engine.pcLock.Unlock() |
||||
} |
||||
|
||||
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||
func DefaultServer(conf *ServerConfig) *Engine { |
||||
engine := NewServer(conf) |
||||
engine.Use(Recovery(), Trace(), Logger()) |
||||
return engine |
||||
} |
||||
|
||||
// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||
//
|
||||
// Deprecated: please use DefaultServer.
|
||||
func Default() *Engine { |
||||
engine := New() |
||||
engine.Use(Recovery(), Trace(), Logger()) |
||||
return engine |
||||
} |
||||
|
||||
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) { |
||||
if path[0] != '/' { |
||||
panic("blademaster: path must begin with '/'") |
||||
} |
||||
if method == "" { |
||||
panic("blademaster: HTTP method can not be empty") |
||||
} |
||||
if len(handlers) == 0 { |
||||
panic("blademaster: there must be at least one handler") |
||||
} |
||||
if _, ok := engine.metastore[path]; !ok { |
||||
engine.metastore[path] = make(map[string]interface{}) |
||||
} |
||||
engine.metastore[path]["method"] = method |
||||
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { |
||||
c := &Context{ |
||||
Context: nil, |
||||
engine: engine, |
||||
index: -1, |
||||
handlers: nil, |
||||
Keys: nil, |
||||
method: "", |
||||
Error: nil, |
||||
} |
||||
|
||||
c.Request = req |
||||
c.Writer = w |
||||
c.handlers = handlers |
||||
c.method = method |
||||
|
||||
engine.handleContext(c) |
||||
}) |
||||
} |
||||
|
||||
// SetConfig is used to set the engine configuration.
|
||||
// Only the valid config will be loaded.
|
||||
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) { |
||||
if conf.Timeout <= 0 { |
||||
return errors.New("blademaster: config timeout must greater than 0") |
||||
} |
||||
if conf.Network == "" { |
||||
conf.Network = "tcp" |
||||
} |
||||
engine.lock.Lock() |
||||
engine.conf = conf |
||||
engine.lock.Unlock() |
||||
return |
||||
} |
||||
|
||||
func (engine *Engine) methodConfig(path string) *MethodConfig { |
||||
engine.pcLock.RLock() |
||||
mc := engine.methodConfigs[path] |
||||
engine.pcLock.RUnlock() |
||||
return mc |
||||
} |
||||
|
||||
func (engine *Engine) handleContext(c *Context) { |
||||
var cancel func() |
||||
req := c.Request |
||||
ctype := req.Header.Get("Content-Type") |
||||
switch { |
||||
case strings.Contains(ctype, "multipart/form-data"): |
||||
req.ParseMultipartForm(defaultMaxMemory) |
||||
default: |
||||
req.ParseForm() |
||||
} |
||||
// get derived timeout from http request header,
|
||||
// compare with the engine configured,
|
||||
// and use the minimum one
|
||||
engine.lock.RLock() |
||||
tm := time.Duration(engine.conf.Timeout) |
||||
engine.lock.RUnlock() |
||||
// the method config is preferred
|
||||
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil { |
||||
tm = time.Duration(pc.Timeout) |
||||
} |
||||
if ctm := timeout(req); ctm > 0 && tm > ctm { |
||||
tm = ctm |
||||
} |
||||
md := metadata.MD{ |
||||
metadata.Color: color(req), |
||||
metadata.RemoteIP: remoteIP(req), |
||||
metadata.RemotePort: remotePort(req), |
||||
metadata.Caller: caller(req), |
||||
metadata.Mirror: mirror(req), |
||||
} |
||||
ctx := metadata.NewContext(context.Background(), md) |
||||
if tm > 0 { |
||||
c.Context, cancel = context.WithTimeout(ctx, tm) |
||||
} else { |
||||
c.Context, cancel = context.WithCancel(ctx) |
||||
} |
||||
defer cancel() |
||||
c.Next() |
||||
} |
||||
|
||||
// Router return a http.Handler for using http.ListenAndServe() directly.
|
||||
func (engine *Engine) Router() http.Handler { |
||||
return engine.mux |
||||
} |
||||
|
||||
// Server is used to load stored http server.
|
||||
func (engine *Engine) Server() *http.Server { |
||||
s, ok := engine.server.Load().(*http.Server) |
||||
if !ok { |
||||
return nil |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// Shutdown the http server without interrupting active connections.
|
||||
func (engine *Engine) Shutdown(ctx context.Context) error { |
||||
server := engine.Server() |
||||
if server == nil { |
||||
return errors.New("blademaster: no server") |
||||
} |
||||
return errors.WithStack(server.Shutdown(ctx)) |
||||
} |
||||
|
||||
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||
engine.RouterGroup.UseFunc(middleware...) |
||||
return engine |
||||
} |
||||
|
||||
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) Use(middleware ...Handler) IRoutes { |
||||
engine.RouterGroup.Use(middleware...) |
||||
return engine |
||||
} |
||||
|
||||
// Ping is used to set the general HTTP ping handler.
|
||||
func (engine *Engine) Ping(handler HandlerFunc) { |
||||
engine.GET("/monitor/ping", handler) |
||||
} |
||||
|
||||
// Register is used to export metadata to discovery.
|
||||
func (engine *Engine) Register(handler HandlerFunc) { |
||||
engine.GET("/register", handler) |
||||
} |
||||
|
||||
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
|
||||
// It is a shortcut for http.ListenAndServe(addr, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) Run(addr ...string) (err error) { |
||||
address := resolveAddress(addr) |
||||
server := &http.Server{ |
||||
Addr: address, |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.ListenAndServe(); err != nil { |
||||
err = errors.Wrapf(err, "addrs: %v", addr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
|
||||
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { |
||||
server := &http.Server{ |
||||
Addr: addr, |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { |
||||
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
|
||||
// through the specified unix socket (ie. a file).
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunUnix(file string) (err error) { |
||||
os.Remove(file) |
||||
listener, err := net.Listen("unix", file) |
||||
if err != nil { |
||||
err = errors.Wrapf(err, "unix: %s", file) |
||||
return |
||||
} |
||||
defer listener.Close() |
||||
server := &http.Server{ |
||||
Handler: engine.mux, |
||||
} |
||||
engine.server.Store(server) |
||||
if err = server.Serve(listener); err != nil { |
||||
err = errors.Wrapf(err, "unix: %s", file) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// RunServer will serve and start listening HTTP requests by given server and listener.
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) { |
||||
server.Handler = engine.mux |
||||
engine.server.Store(server) |
||||
if err = server.Serve(l); err != nil { |
||||
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l) |
||||
return |
||||
} |
||||
return |
||||
} |
||||
|
||||
func (engine *Engine) metadata() HandlerFunc { |
||||
return func(c *Context) { |
||||
c.JSON(engine.metastore, nil) |
||||
} |
||||
} |
||||
|
||||
// Inject is
|
||||
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) { |
||||
engine.injections = append(engine.injections, injection{ |
||||
pattern: regexp.MustCompile(pattern), |
||||
handlers: handlers, |
||||
}) |
||||
} |
@ -0,0 +1,42 @@ |
||||
package blademaster |
||||
|
||||
import ( |
||||
"os" |
||||
"path" |
||||
) |
||||
|
||||
func lastChar(str string) uint8 { |
||||
if str == "" { |
||||
panic("The length of the string can't be 0") |
||||
} |
||||
return str[len(str)-1] |
||||
} |
||||
|
||||
func joinPaths(absolutePath, relativePath string) string { |
||||
if relativePath == "" { |
||||
return absolutePath |
||||
} |
||||
|
||||
finalPath := path.Join(absolutePath, relativePath) |
||||
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' |
||||
if appendSlash { |
||||
return finalPath + "/" |
||||
} |
||||
return finalPath |
||||
} |
||||
|
||||
func resolveAddress(addr []string) string { |
||||
switch len(addr) { |
||||
case 0: |
||||
if port := os.Getenv("PORT"); port != "" { |
||||
//debugPrint("Environment variable PORT=\"%s\"", port)
|
||||
return ":" + port |
||||
} |
||||
//debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
|
||||
return ":8080" |
||||
case 1: |
||||
return addr[0] |
||||
default: |
||||
panic("too much parameters") |
||||
} |
||||
} |
Loading…
Reference in new issue