blademaster initial (#6)
parent
1efe0a084e
commit
96d32e866a
@ -0,0 +1,85 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"gopkg.in/go-playground/validator.v9" |
||||||
|
) |
||||||
|
|
||||||
|
// MIME
|
||||||
|
const ( |
||||||
|
MIMEJSON = "application/json" |
||||||
|
MIMEHTML = "text/html" |
||||||
|
MIMEXML = "application/xml" |
||||||
|
MIMEXML2 = "text/xml" |
||||||
|
MIMEPlain = "text/plain" |
||||||
|
MIMEPOSTForm = "application/x-www-form-urlencoded" |
||||||
|
MIMEMultipartPOSTForm = "multipart/form-data" |
||||||
|
) |
||||||
|
|
||||||
|
// Binding http binding request interface.
|
||||||
|
type Binding interface { |
||||||
|
Name() string |
||||||
|
Bind(*http.Request, interface{}) error |
||||||
|
} |
||||||
|
|
||||||
|
// StructValidator http validator interface.
|
||||||
|
type StructValidator interface { |
||||||
|
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||||
|
// If the received type is not a struct, any validation should be skipped and nil must be returned.
|
||||||
|
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||||
|
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||||
|
// Otherwise nil must be returned.
|
||||||
|
ValidateStruct(interface{}) error |
||||||
|
|
||||||
|
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||||
|
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||||
|
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||||
|
RegisterValidation(string, validator.Func) error |
||||||
|
} |
||||||
|
|
||||||
|
// Validator default validator.
|
||||||
|
var Validator StructValidator = &defaultValidator{} |
||||||
|
|
||||||
|
// Binding
|
||||||
|
var ( |
||||||
|
JSON = jsonBinding{} |
||||||
|
XML = xmlBinding{} |
||||||
|
Form = formBinding{} |
||||||
|
Query = queryBinding{} |
||||||
|
FormPost = formPostBinding{} |
||||||
|
FormMultipart = formMultipartBinding{} |
||||||
|
) |
||||||
|
|
||||||
|
// Default get by binding type by method and contexttype.
|
||||||
|
func Default(method, contentType string) Binding { |
||||||
|
if method == "GET" { |
||||||
|
return Form |
||||||
|
} |
||||||
|
|
||||||
|
contentType = stripContentTypeParam(contentType) |
||||||
|
switch contentType { |
||||||
|
case MIMEJSON: |
||||||
|
return JSON |
||||||
|
case MIMEXML, MIMEXML2: |
||||||
|
return XML |
||||||
|
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
|
||||||
|
return Form |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func validate(obj interface{}) error { |
||||||
|
if Validator == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return Validator.ValidateStruct(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func stripContentTypeParam(contentType string) string { |
||||||
|
i := strings.Index(contentType, ";") |
||||||
|
if i != -1 { |
||||||
|
contentType = contentType[:i] |
||||||
|
} |
||||||
|
return contentType |
||||||
|
} |
@ -0,0 +1,342 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"mime/multipart" |
||||||
|
"net/http" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
) |
||||||
|
|
||||||
|
type FooStruct struct { |
||||||
|
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"` |
||||||
|
} |
||||||
|
|
||||||
|
type FooBarStruct struct { |
||||||
|
FooStruct |
||||||
|
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"` |
||||||
|
Slice []string `form:"slice" validate:"max=10"` |
||||||
|
} |
||||||
|
|
||||||
|
type ComplexDefaultStruct struct { |
||||||
|
Int int `form:"int" default:"999"` |
||||||
|
String string `form:"string" default:"default-string"` |
||||||
|
Bool bool `form:"bool" default:"false"` |
||||||
|
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"` |
||||||
|
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"` |
||||||
|
} |
||||||
|
|
||||||
|
type Int8SliceStruct struct { |
||||||
|
State []int8 `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
type Int64SliceStruct struct { |
||||||
|
State []int64 `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
type StringSliceStruct struct { |
||||||
|
State []string `form:"state,split"` |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingDefault(t *testing.T) { |
||||||
|
assert.Equal(t, Default("GET", ""), Form) |
||||||
|
assert.Equal(t, Default("GET", MIMEJSON), Form) |
||||||
|
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEJSON), JSON) |
||||||
|
assert.Equal(t, Default("PUT", MIMEJSON), JSON) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON) |
||||||
|
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEXML), XML) |
||||||
|
assert.Equal(t, Default("PUT", MIMEXML2), XML) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEPOSTForm), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form) |
||||||
|
|
||||||
|
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form) |
||||||
|
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form) |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
func TestStripContentType(t *testing.T) { |
||||||
|
c1 := "application/vnd.mozilla.xul+xml" |
||||||
|
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8" |
||||||
|
assert.Equal(t, stripContentTypeParam(c1), c1) |
||||||
|
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindInt8Form(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(Int8SliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3}, q.State) |
||||||
|
|
||||||
|
params = "state=1,2,3,256" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.Error(t, Form.Bind(req, q)) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
|
||||||
|
params = "state=1,,2" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int8SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.EqualValues(t, []int8{1, 2}, q.State) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindInt64Form(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(Int64SliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3}, q.State) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(Int64SliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindStringForm(t *testing.T) { |
||||||
|
params := "state=1,2,3" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(StringSliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []string{"1", "2", "3"}, q.State) |
||||||
|
|
||||||
|
params = "state=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(StringSliceStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Len(t, q.State, 0) |
||||||
|
|
||||||
|
params = "state=p,,p" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(StringSliceStruct) |
||||||
|
Form.Bind(req, q) |
||||||
|
assert.EqualValues(t, []string{"p", "p"}, q.State) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingJSON(t *testing.T) { |
||||||
|
testBodyBinding(t, |
||||||
|
JSON, "json", |
||||||
|
"/", "/", |
||||||
|
`{"foo": "bar"}`, `{"bar": "foo"}`) |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingForm(t *testing.T) { |
||||||
|
testFormBinding(t, "POST", |
||||||
|
"/", "/", |
||||||
|
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingForm2(t *testing.T) { |
||||||
|
testFormBinding(t, "GET", |
||||||
|
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||||
|
"", "") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingQuery(t *testing.T) { |
||||||
|
testQueryBinding(t, "POST", |
||||||
|
"/?foo=bar&bar=foo", "/", |
||||||
|
"foo=unused", "bar2=foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingQuery2(t *testing.T) { |
||||||
|
testQueryBinding(t, "GET", |
||||||
|
"/?foo=bar&bar=foo", "/?bar2=foo", |
||||||
|
"foo=unused", "") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingXML(t *testing.T) { |
||||||
|
testBodyBinding(t, |
||||||
|
XML, "xml", |
||||||
|
"/", "/", |
||||||
|
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>") |
||||||
|
} |
||||||
|
|
||||||
|
func createFormPostRequest() *http.Request { |
||||||
|
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) |
||||||
|
req.Header.Set("Content-Type", MIMEPOSTForm) |
||||||
|
return req |
||||||
|
} |
||||||
|
|
||||||
|
func createFormMultipartRequest() *http.Request { |
||||||
|
boundary := "--testboundary" |
||||||
|
body := new(bytes.Buffer) |
||||||
|
mw := multipart.NewWriter(body) |
||||||
|
defer mw.Close() |
||||||
|
|
||||||
|
mw.SetBoundary(boundary) |
||||||
|
mw.WriteField("foo", "bar") |
||||||
|
mw.WriteField("bar", "foo") |
||||||
|
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) |
||||||
|
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) |
||||||
|
return req |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingFormPost(t *testing.T) { |
||||||
|
req := createFormPostRequest() |
||||||
|
var obj FooBarStruct |
||||||
|
FormPost.Bind(req, &obj) |
||||||
|
|
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestBindingFormMultipart(t *testing.T) { |
||||||
|
req := createFormMultipartRequest() |
||||||
|
var obj FooBarStruct |
||||||
|
FormMultipart.Bind(req, &obj) |
||||||
|
|
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidationFails(t *testing.T) { |
||||||
|
var obj FooStruct |
||||||
|
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidationDisabled(t *testing.T) { |
||||||
|
backup := Validator |
||||||
|
Validator = nil |
||||||
|
defer func() { Validator = backup }() |
||||||
|
|
||||||
|
var obj FooStruct |
||||||
|
req := requestWithBody("POST", "/", `{"bar": "foo"}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestExistsSucceeds(t *testing.T) { |
||||||
|
type HogeStruct struct { |
||||||
|
Hoge *int `json:"hoge" binding:"exists"` |
||||||
|
} |
||||||
|
|
||||||
|
var obj HogeStruct |
||||||
|
req := requestWithBody("POST", "/", `{"hoge": 0}`) |
||||||
|
err := JSON.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestFormDefaultValue(t *testing.T) { |
||||||
|
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8" |
||||||
|
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q := new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 333, q.Int) |
||||||
|
assert.Equal(t, "hello", q.String) |
||||||
|
assert.Equal(t, true, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "string=hello&bool=false" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "hello", q.String) |
||||||
|
assert.Equal(t, false, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "strings=hello" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "default-string", q.String) |
||||||
|
assert.Equal(t, false, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
|
||||||
|
params = "int=&string=&bool=true&int64_slice=&int8_slice=" |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
q = new(ComplexDefaultStruct) |
||||||
|
assert.NoError(t, Form.Bind(req, q)) |
||||||
|
assert.Equal(t, 999, q.Int) |
||||||
|
assert.Equal(t, "default-string", q.String) |
||||||
|
assert.Equal(t, true, q.Bool) |
||||||
|
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice) |
||||||
|
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice) |
||||||
|
} |
||||||
|
|
||||||
|
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||||
|
b := Form |
||||||
|
assert.Equal(t, b.Name(), "form") |
||||||
|
|
||||||
|
obj := FooBarStruct{} |
||||||
|
req := requestWithBody(method, path, body) |
||||||
|
if method == "POST" { |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
} |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
|
||||||
|
obj = FooBarStruct{} |
||||||
|
req = requestWithBody(method, badPath, badBody) |
||||||
|
err = JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) { |
||||||
|
b := Query |
||||||
|
assert.Equal(t, b.Name(), "query") |
||||||
|
|
||||||
|
obj := FooBarStruct{} |
||||||
|
req := requestWithBody(method, path, body) |
||||||
|
if method == "POST" { |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
} |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
assert.Equal(t, obj.Bar, "foo") |
||||||
|
} |
||||||
|
|
||||||
|
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) { |
||||||
|
assert.Equal(t, b.Name(), name) |
||||||
|
|
||||||
|
obj := FooStruct{} |
||||||
|
req := requestWithBody("POST", path, body) |
||||||
|
err := b.Bind(req, &obj) |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.Equal(t, obj.Foo, "bar") |
||||||
|
|
||||||
|
obj = FooStruct{} |
||||||
|
req = requestWithBody("POST", badPath, badBody) |
||||||
|
err = JSON.Bind(req, &obj) |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func requestWithBody(method, path, body string) (req *http.Request) { |
||||||
|
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body)) |
||||||
|
return |
||||||
|
} |
||||||
|
func BenchmarkBindingForm(b *testing.B) { |
||||||
|
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w") |
||||||
|
req.Header.Add("Content-Type", MIMEPOSTForm) |
||||||
|
f := Form |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
obj := FooBarStruct{} |
||||||
|
f.Bind(req, &obj) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,45 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"reflect" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"gopkg.in/go-playground/validator.v9" |
||||||
|
) |
||||||
|
|
||||||
|
type defaultValidator struct { |
||||||
|
once sync.Once |
||||||
|
validate *validator.Validate |
||||||
|
} |
||||||
|
|
||||||
|
var _ StructValidator = &defaultValidator{} |
||||||
|
|
||||||
|
func (v *defaultValidator) ValidateStruct(obj interface{}) error { |
||||||
|
if kindOfData(obj) == reflect.Struct { |
||||||
|
v.lazyinit() |
||||||
|
if err := v.validate.Struct(obj); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error { |
||||||
|
v.lazyinit() |
||||||
|
return v.validate.RegisterValidation(key, fn) |
||||||
|
} |
||||||
|
|
||||||
|
func (v *defaultValidator) lazyinit() { |
||||||
|
v.once.Do(func() { |
||||||
|
v.validate = validator.New() |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func kindOfData(data interface{}) reflect.Kind { |
||||||
|
value := reflect.ValueOf(data) |
||||||
|
valueType := value.Kind() |
||||||
|
if valueType == reflect.Ptr { |
||||||
|
valueType = value.Elem().Kind() |
||||||
|
} |
||||||
|
return valueType |
||||||
|
} |
@ -0,0 +1,113 @@ |
|||||||
|
// Code generated by protoc-gen-go.
|
||||||
|
// source: test.proto
|
||||||
|
// DO NOT EDIT!
|
||||||
|
|
||||||
|
/* |
||||||
|
Package example is a generated protocol buffer package. |
||||||
|
|
||||||
|
It is generated from these files: |
||||||
|
test.proto |
||||||
|
|
||||||
|
It has these top-level messages: |
||||||
|
Test |
||||||
|
*/ |
||||||
|
package example |
||||||
|
|
||||||
|
import proto "github.com/golang/protobuf/proto" |
||||||
|
import math "math" |
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal |
||||||
|
var _ = math.Inf |
||||||
|
|
||||||
|
type FOO int32 |
||||||
|
|
||||||
|
const ( |
||||||
|
FOO_X FOO = 17 |
||||||
|
) |
||||||
|
|
||||||
|
var FOO_name = map[int32]string{ |
||||||
|
17: "X", |
||||||
|
} |
||||||
|
var FOO_value = map[string]int32{ |
||||||
|
"X": 17, |
||||||
|
} |
||||||
|
|
||||||
|
func (x FOO) Enum() *FOO { |
||||||
|
p := new(FOO) |
||||||
|
*p = x |
||||||
|
return p |
||||||
|
} |
||||||
|
func (x FOO) String() string { |
||||||
|
return proto.EnumName(FOO_name, int32(x)) |
||||||
|
} |
||||||
|
func (x *FOO) UnmarshalJSON(data []byte) error { |
||||||
|
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
*x = FOO(value) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
type Test struct { |
||||||
|
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` |
||||||
|
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` |
||||||
|
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` |
||||||
|
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` |
||||||
|
XXX_unrecognized []byte `json:"-"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) Reset() { *m = Test{} } |
||||||
|
func (m *Test) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*Test) ProtoMessage() {} |
||||||
|
|
||||||
|
const Default_Test_Type int32 = 77 |
||||||
|
|
||||||
|
func (m *Test) GetLabel() string { |
||||||
|
if m != nil && m.Label != nil { |
||||||
|
return *m.Label |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetType() int32 { |
||||||
|
if m != nil && m.Type != nil { |
||||||
|
return *m.Type |
||||||
|
} |
||||||
|
return Default_Test_Type |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetReps() []int64 { |
||||||
|
if m != nil { |
||||||
|
return m.Reps |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test) GetOptionalgroup() *Test_OptionalGroup { |
||||||
|
if m != nil { |
||||||
|
return m.Optionalgroup |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
type Test_OptionalGroup struct { |
||||||
|
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` |
||||||
|
XXX_unrecognized []byte `json:"-"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } |
||||||
|
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*Test_OptionalGroup) ProtoMessage() {} |
||||||
|
|
||||||
|
func (m *Test_OptionalGroup) GetRequiredField() string { |
||||||
|
if m != nil && m.RequiredField != nil { |
||||||
|
return *m.RequiredField |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func init() { |
||||||
|
proto.RegisterEnum("example.FOO", FOO_name, FOO_value) |
||||||
|
} |
@ -0,0 +1,12 @@ |
|||||||
|
package example; |
||||||
|
|
||||||
|
enum FOO {X=17;}; |
||||||
|
|
||||||
|
message Test { |
||||||
|
required string label = 1; |
||||||
|
optional int32 type = 2[default=77]; |
||||||
|
repeated int64 reps = 3; |
||||||
|
optional group OptionalGroup = 4{ |
||||||
|
required string RequiredField = 5; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"log" |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
type Arg struct { |
||||||
|
Max int64 `form:"max" validate:"max=10"` |
||||||
|
Min int64 `form:"min" validate:"min=2"` |
||||||
|
Range int64 `form:"range" validate:"min=1,max=10"` |
||||||
|
// use split option to split arg 1,2,3 into slice [1 2 3]
|
||||||
|
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
|
||||||
|
Slice []int64 `form:"slice,split" validate:"min=1"` |
||||||
|
} |
||||||
|
|
||||||
|
func ExampleBinding() { |
||||||
|
req := initHTTP("max=9&min=3&range=3&slice=1,2,3") |
||||||
|
arg := new(Arg) |
||||||
|
if err := Form.Bind(req, arg); err != nil { |
||||||
|
log.Fatal(err) |
||||||
|
} |
||||||
|
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice) |
||||||
|
// Output:
|
||||||
|
// arg.Max 9
|
||||||
|
// arg.Min 3
|
||||||
|
// arg.Range 3
|
||||||
|
// arg.Slice [1 2 3]
|
||||||
|
} |
||||||
|
|
||||||
|
func initHTTP(params string) (req *http.Request) { |
||||||
|
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil) |
||||||
|
req.ParseForm() |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,55 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const defaultMemory = 32 * 1024 * 1024 |
||||||
|
|
||||||
|
type formBinding struct{} |
||||||
|
type formPostBinding struct{} |
||||||
|
type formMultipartBinding struct{} |
||||||
|
|
||||||
|
func (f formBinding) Name() string { |
||||||
|
return "form" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseForm(); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.Form); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func (f formPostBinding) Name() string { |
||||||
|
return "form-urlencoded" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseForm(); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.PostForm); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
||||||
|
|
||||||
|
func (f formMultipartBinding) Name() string { |
||||||
|
return "multipart/form-data" |
||||||
|
} |
||||||
|
|
||||||
|
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
if err := req.ParseMultipartForm(defaultMemory); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
if err := mapForm(obj, req.MultipartForm.Value); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,276 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"reflect" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// scache struct reflect type cache.
|
||||||
|
var scache = &cache{ |
||||||
|
data: make(map[reflect.Type]*sinfo), |
||||||
|
} |
||||||
|
|
||||||
|
type cache struct { |
||||||
|
data map[reflect.Type]*sinfo |
||||||
|
mutex sync.RWMutex |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cache) get(obj reflect.Type) (s *sinfo) { |
||||||
|
var ok bool |
||||||
|
c.mutex.RLock() |
||||||
|
if s, ok = c.data[obj]; !ok { |
||||||
|
c.mutex.RUnlock() |
||||||
|
s = c.set(obj) |
||||||
|
return |
||||||
|
} |
||||||
|
c.mutex.RUnlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (c *cache) set(obj reflect.Type) (s *sinfo) { |
||||||
|
s = new(sinfo) |
||||||
|
tp := obj.Elem() |
||||||
|
for i := 0; i < tp.NumField(); i++ { |
||||||
|
fd := new(field) |
||||||
|
fd.tp = tp.Field(i) |
||||||
|
tag := fd.tp.Tag.Get("form") |
||||||
|
fd.name, fd.option = parseTag(tag) |
||||||
|
if defV := fd.tp.Tag.Get("default"); defV != "" { |
||||||
|
dv := reflect.New(fd.tp.Type).Elem() |
||||||
|
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option) |
||||||
|
fd.hasDefault = true |
||||||
|
fd.defaultValue = dv |
||||||
|
} |
||||||
|
s.field = append(s.field, fd) |
||||||
|
} |
||||||
|
c.mutex.Lock() |
||||||
|
c.data[obj] = s |
||||||
|
c.mutex.Unlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
type sinfo struct { |
||||||
|
field []*field |
||||||
|
} |
||||||
|
|
||||||
|
type field struct { |
||||||
|
tp reflect.StructField |
||||||
|
name string |
||||||
|
option tagOptions |
||||||
|
|
||||||
|
hasDefault bool // if field had default value
|
||||||
|
defaultValue reflect.Value // field default value
|
||||||
|
} |
||||||
|
|
||||||
|
func mapForm(ptr interface{}, form map[string][]string) error { |
||||||
|
sinfo := scache.get(reflect.TypeOf(ptr)) |
||||||
|
val := reflect.ValueOf(ptr).Elem() |
||||||
|
for i, fd := range sinfo.field { |
||||||
|
typeField := fd.tp |
||||||
|
structField := val.Field(i) |
||||||
|
if !structField.CanSet() { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
structFieldKind := structField.Kind() |
||||||
|
inputFieldName := fd.name |
||||||
|
if inputFieldName == "" { |
||||||
|
inputFieldName = typeField.Name |
||||||
|
|
||||||
|
// if "form" tag is nil, we inspect if the field is a struct.
|
||||||
|
// this would not make sense for JSON parsing but it does for a form
|
||||||
|
// since data is flatten
|
||||||
|
if structFieldKind == reflect.Struct { |
||||||
|
err := mapForm(structField.Addr().Interface(), form) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
} |
||||||
|
inputValue, exists := form[inputFieldName] |
||||||
|
if !exists { |
||||||
|
// Set the field as default value when the input value is not exist
|
||||||
|
if fd.hasDefault { |
||||||
|
structField.Set(fd.defaultValue) |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
// Set the field as default value when the input value is empty
|
||||||
|
if fd.hasDefault && inputValue[0] == "" { |
||||||
|
structField.Set(fd.defaultValue) |
||||||
|
continue |
||||||
|
} |
||||||
|
if _, isTime := structField.Interface().(time.Time); isTime { |
||||||
|
if err := setTimeField(inputValue[0], typeField, structField); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error { |
||||||
|
switch valueKind { |
||||||
|
case reflect.Int: |
||||||
|
return setIntField(val[0], 0, structField) |
||||||
|
case reflect.Int8: |
||||||
|
return setIntField(val[0], 8, structField) |
||||||
|
case reflect.Int16: |
||||||
|
return setIntField(val[0], 16, structField) |
||||||
|
case reflect.Int32: |
||||||
|
return setIntField(val[0], 32, structField) |
||||||
|
case reflect.Int64: |
||||||
|
return setIntField(val[0], 64, structField) |
||||||
|
case reflect.Uint: |
||||||
|
return setUintField(val[0], 0, structField) |
||||||
|
case reflect.Uint8: |
||||||
|
return setUintField(val[0], 8, structField) |
||||||
|
case reflect.Uint16: |
||||||
|
return setUintField(val[0], 16, structField) |
||||||
|
case reflect.Uint32: |
||||||
|
return setUintField(val[0], 32, structField) |
||||||
|
case reflect.Uint64: |
||||||
|
return setUintField(val[0], 64, structField) |
||||||
|
case reflect.Bool: |
||||||
|
return setBoolField(val[0], structField) |
||||||
|
case reflect.Float32: |
||||||
|
return setFloatField(val[0], 32, structField) |
||||||
|
case reflect.Float64: |
||||||
|
return setFloatField(val[0], 64, structField) |
||||||
|
case reflect.String: |
||||||
|
structField.SetString(val[0]) |
||||||
|
case reflect.Slice: |
||||||
|
if option.Contains("split") { |
||||||
|
val = strings.Split(val[0], ",") |
||||||
|
} |
||||||
|
filtered := filterEmpty(val) |
||||||
|
switch structField.Type().Elem().Kind() { |
||||||
|
case reflect.Int64: |
||||||
|
valSli := make([]int64, 0, len(filtered)) |
||||||
|
for i := 0; i < len(filtered); i++ { |
||||||
|
d, err := strconv.ParseInt(filtered[i], 10, 64) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
valSli = append(valSli, d) |
||||||
|
} |
||||||
|
structField.Set(reflect.ValueOf(valSli)) |
||||||
|
case reflect.String: |
||||||
|
valSli := make([]string, 0, len(filtered)) |
||||||
|
for i := 0; i < len(filtered); i++ { |
||||||
|
valSli = append(valSli, filtered[i]) |
||||||
|
} |
||||||
|
structField.Set(reflect.ValueOf(valSli)) |
||||||
|
default: |
||||||
|
sliceOf := structField.Type().Elem().Kind() |
||||||
|
numElems := len(filtered) |
||||||
|
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered)) |
||||||
|
for i := 0; i < numElems; i++ { |
||||||
|
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
structField.Set(slice) |
||||||
|
} |
||||||
|
default: |
||||||
|
return errors.New("Unknown type") |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setIntField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0" |
||||||
|
} |
||||||
|
intVal, err := strconv.ParseInt(val, 10, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetInt(intVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setUintField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0" |
||||||
|
} |
||||||
|
uintVal, err := strconv.ParseUint(val, 10, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetUint(uintVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setBoolField(val string, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "false" |
||||||
|
} |
||||||
|
boolVal, err := strconv.ParseBool(val) |
||||||
|
if err == nil { |
||||||
|
field.SetBool(boolVal) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func setFloatField(val string, bitSize int, field reflect.Value) error { |
||||||
|
if val == "" { |
||||||
|
val = "0.0" |
||||||
|
} |
||||||
|
floatVal, err := strconv.ParseFloat(val, bitSize) |
||||||
|
if err == nil { |
||||||
|
field.SetFloat(floatVal) |
||||||
|
} |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { |
||||||
|
timeFormat := structField.Tag.Get("time_format") |
||||||
|
if timeFormat == "" { |
||||||
|
return errors.New("Blank time format") |
||||||
|
} |
||||||
|
|
||||||
|
if val == "" { |
||||||
|
value.Set(reflect.ValueOf(time.Time{})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
l := time.Local |
||||||
|
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { |
||||||
|
l = time.UTC |
||||||
|
} |
||||||
|
|
||||||
|
if locTag := structField.Tag.Get("time_location"); locTag != "" { |
||||||
|
loc, err := time.LoadLocation(locTag) |
||||||
|
if err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
l = loc |
||||||
|
} |
||||||
|
|
||||||
|
t, err := time.ParseInLocation(timeFormat, val, l) |
||||||
|
if err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
|
||||||
|
value.Set(reflect.ValueOf(t)) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func filterEmpty(val []string) []string { |
||||||
|
filtered := make([]string, 0, len(val)) |
||||||
|
for _, v := range val { |
||||||
|
if v != "" { |
||||||
|
filtered = append(filtered, v) |
||||||
|
} |
||||||
|
} |
||||||
|
return filtered |
||||||
|
} |
@ -0,0 +1,22 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
type jsonBinding struct{} |
||||||
|
|
||||||
|
func (jsonBinding) Name() string { |
||||||
|
return "json" |
||||||
|
} |
||||||
|
|
||||||
|
func (jsonBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
decoder := json.NewDecoder(req.Body) |
||||||
|
if err := decoder.Decode(obj); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,19 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
type queryBinding struct{} |
||||||
|
|
||||||
|
func (queryBinding) Name() string { |
||||||
|
return "query" |
||||||
|
} |
||||||
|
|
||||||
|
func (queryBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
values := req.URL.Query() |
||||||
|
if err := mapForm(obj, values); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,44 @@ |
|||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
// tagOptions is the string following a comma in a struct field's "json"
|
||||||
|
// tag, or the empty string. It does not include the leading comma.
|
||||||
|
type tagOptions string |
||||||
|
|
||||||
|
// parseTag splits a struct field's json tag into its name and
|
||||||
|
// comma-separated options.
|
||||||
|
func parseTag(tag string) (string, tagOptions) { |
||||||
|
if idx := strings.Index(tag, ","); idx != -1 { |
||||||
|
return tag[:idx], tagOptions(tag[idx+1:]) |
||||||
|
} |
||||||
|
return tag, tagOptions("") |
||||||
|
} |
||||||
|
|
||||||
|
// Contains reports whether a comma-separated list of options
|
||||||
|
// contains a particular substr flag. substr must be surrounded by a
|
||||||
|
// string boundary or commas.
|
||||||
|
func (o tagOptions) Contains(optionName string) bool { |
||||||
|
if len(o) == 0 { |
||||||
|
return false |
||||||
|
} |
||||||
|
s := string(o) |
||||||
|
for s != "" { |
||||||
|
var next string |
||||||
|
i := strings.Index(s, ",") |
||||||
|
if i >= 0 { |
||||||
|
s, next = s[:i], s[i+1:] |
||||||
|
} |
||||||
|
if s == optionName { |
||||||
|
return true |
||||||
|
} |
||||||
|
s = next |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
@ -0,0 +1,209 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
) |
||||||
|
|
||||||
|
type testInterface interface { |
||||||
|
String() string |
||||||
|
} |
||||||
|
|
||||||
|
type substructNoValidation struct { |
||||||
|
IString string |
||||||
|
IInt int |
||||||
|
} |
||||||
|
|
||||||
|
type mapNoValidationSub map[string]substructNoValidation |
||||||
|
|
||||||
|
type structNoValidationValues struct { |
||||||
|
substructNoValidation |
||||||
|
|
||||||
|
Boolean bool |
||||||
|
|
||||||
|
Uinteger uint |
||||||
|
Integer int |
||||||
|
Integer8 int8 |
||||||
|
Integer16 int16 |
||||||
|
Integer32 int32 |
||||||
|
Integer64 int64 |
||||||
|
Uinteger8 uint8 |
||||||
|
Uinteger16 uint16 |
||||||
|
Uinteger32 uint32 |
||||||
|
Uinteger64 uint64 |
||||||
|
|
||||||
|
Float32 float32 |
||||||
|
Float64 float64 |
||||||
|
|
||||||
|
String string |
||||||
|
|
||||||
|
Date time.Time |
||||||
|
|
||||||
|
Struct substructNoValidation |
||||||
|
InlinedStruct struct { |
||||||
|
String []string |
||||||
|
Integer int |
||||||
|
} |
||||||
|
|
||||||
|
IntSlice []int |
||||||
|
IntPointerSlice []*int |
||||||
|
StructPointerSlice []*substructNoValidation |
||||||
|
StructSlice []substructNoValidation |
||||||
|
InterfaceSlice []testInterface |
||||||
|
|
||||||
|
UniversalInterface interface{} |
||||||
|
CustomInterface testInterface |
||||||
|
|
||||||
|
FloatMap map[string]float32 |
||||||
|
StructMap mapNoValidationSub |
||||||
|
} |
||||||
|
|
||||||
|
func createNoValidationValues() structNoValidationValues { |
||||||
|
integer := 1 |
||||||
|
s := structNoValidationValues{ |
||||||
|
Boolean: true, |
||||||
|
Uinteger: 1 << 29, |
||||||
|
Integer: -10000, |
||||||
|
Integer8: 120, |
||||||
|
Integer16: -20000, |
||||||
|
Integer32: 1 << 29, |
||||||
|
Integer64: 1 << 61, |
||||||
|
Uinteger8: 250, |
||||||
|
Uinteger16: 50000, |
||||||
|
Uinteger32: 1 << 31, |
||||||
|
Uinteger64: 1 << 62, |
||||||
|
Float32: 123.456, |
||||||
|
Float64: 123.456789, |
||||||
|
String: "text", |
||||||
|
Date: time.Time{}, |
||||||
|
CustomInterface: &bytes.Buffer{}, |
||||||
|
Struct: substructNoValidation{}, |
||||||
|
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3}, |
||||||
|
IntPointerSlice: []*int{&integer}, |
||||||
|
StructSlice: []substructNoValidation{}, |
||||||
|
UniversalInterface: 1.2, |
||||||
|
FloatMap: map[string]float32{ |
||||||
|
"foo": 1.23, |
||||||
|
"bar": 232.323, |
||||||
|
}, |
||||||
|
StructMap: mapNoValidationSub{ |
||||||
|
"foo": substructNoValidation{}, |
||||||
|
"bar": substructNoValidation{}, |
||||||
|
}, |
||||||
|
// StructPointerSlice []noValidationSub
|
||||||
|
// InterfaceSlice []testInterface
|
||||||
|
} |
||||||
|
s.InlinedStruct.Integer = 1000 |
||||||
|
s.InlinedStruct.String = []string{"first", "second"} |
||||||
|
s.IString = "substring" |
||||||
|
s.IInt = 987654 |
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateNoValidationValues(t *testing.T) { |
||||||
|
origin := createNoValidationValues() |
||||||
|
test := createNoValidationValues() |
||||||
|
empty := structNoValidationValues{} |
||||||
|
|
||||||
|
assert.Nil(t, validate(test)) |
||||||
|
assert.Nil(t, validate(&test)) |
||||||
|
assert.Nil(t, validate(empty)) |
||||||
|
assert.Nil(t, validate(&empty)) |
||||||
|
|
||||||
|
assert.Equal(t, origin, test) |
||||||
|
} |
||||||
|
|
||||||
|
type structNoValidationPointer struct { |
||||||
|
// substructNoValidation
|
||||||
|
|
||||||
|
Boolean bool |
||||||
|
|
||||||
|
Uinteger *uint |
||||||
|
Integer *int |
||||||
|
Integer8 *int8 |
||||||
|
Integer16 *int16 |
||||||
|
Integer32 *int32 |
||||||
|
Integer64 *int64 |
||||||
|
Uinteger8 *uint8 |
||||||
|
Uinteger16 *uint16 |
||||||
|
Uinteger32 *uint32 |
||||||
|
Uinteger64 *uint64 |
||||||
|
|
||||||
|
Float32 *float32 |
||||||
|
Float64 *float64 |
||||||
|
|
||||||
|
String *string |
||||||
|
|
||||||
|
Date *time.Time |
||||||
|
|
||||||
|
Struct *substructNoValidation |
||||||
|
|
||||||
|
IntSlice *[]int |
||||||
|
IntPointerSlice *[]*int |
||||||
|
StructPointerSlice *[]*substructNoValidation |
||||||
|
StructSlice *[]substructNoValidation |
||||||
|
InterfaceSlice *[]testInterface |
||||||
|
|
||||||
|
FloatMap *map[string]float32 |
||||||
|
StructMap *mapNoValidationSub |
||||||
|
} |
||||||
|
|
||||||
|
func TestValidateNoValidationPointers(t *testing.T) { |
||||||
|
//origin := createNoValidation_values()
|
||||||
|
//test := createNoValidation_values()
|
||||||
|
empty := structNoValidationPointer{} |
||||||
|
|
||||||
|
//assert.Nil(t, validate(test))
|
||||||
|
//assert.Nil(t, validate(&test))
|
||||||
|
assert.Nil(t, validate(empty)) |
||||||
|
assert.Nil(t, validate(&empty)) |
||||||
|
|
||||||
|
//assert.Equal(t, origin, test)
|
||||||
|
} |
||||||
|
|
||||||
|
type Object map[string]interface{} |
||||||
|
|
||||||
|
func TestValidatePrimitives(t *testing.T) { |
||||||
|
obj := Object{"foo": "bar", "bar": 1} |
||||||
|
assert.NoError(t, validate(obj)) |
||||||
|
assert.NoError(t, validate(&obj)) |
||||||
|
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1}) |
||||||
|
|
||||||
|
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}} |
||||||
|
assert.NoError(t, validate(obj2)) |
||||||
|
assert.NoError(t, validate(&obj2)) |
||||||
|
|
||||||
|
nu := 10 |
||||||
|
assert.NoError(t, validate(nu)) |
||||||
|
assert.NoError(t, validate(&nu)) |
||||||
|
assert.Equal(t, nu, 10) |
||||||
|
|
||||||
|
str := "value" |
||||||
|
assert.NoError(t, validate(str)) |
||||||
|
assert.NoError(t, validate(&str)) |
||||||
|
assert.Equal(t, str, "value") |
||||||
|
} |
||||||
|
|
||||||
|
// structCustomValidation is a helper struct we use to check that
|
||||||
|
// custom validation can be registered on it.
|
||||||
|
// The `notone` binding directive is for custom validation and registered later.
|
||||||
|
// type structCustomValidation struct {
|
||||||
|
// Integer int `binding:"notone"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
// notOne is a custom validator meant to be used with `validator.v8` library.
|
||||||
|
// The method signature for `v9` is significantly different and this function
|
||||||
|
// would need to be changed for tests to pass after upgrade.
|
||||||
|
// See https://github.com/gin-gonic/gin/pull/1015.
|
||||||
|
// func notOne(
|
||||||
|
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
|
||||||
|
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
|
||||||
|
// ) bool {
|
||||||
|
// if val, ok := field.Interface().(int); ok {
|
||||||
|
// return val != 1
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
@ -0,0 +1,22 @@ |
|||||||
|
package binding |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/xml" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
type xmlBinding struct{} |
||||||
|
|
||||||
|
func (xmlBinding) Name() string { |
||||||
|
return "xml" |
||||||
|
} |
||||||
|
|
||||||
|
func (xmlBinding) Bind(req *http.Request, obj interface{}) error { |
||||||
|
decoder := xml.NewDecoder(req.Body) |
||||||
|
if err := decoder.Decode(obj); err != nil { |
||||||
|
return errors.WithStack(err) |
||||||
|
} |
||||||
|
return validate(obj) |
||||||
|
} |
@ -0,0 +1,306 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"math" |
||||||
|
"net/http" |
||||||
|
"strconv" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/ecode" |
||||||
|
"github.com/bilibili/Kratos/pkg/net/http/blademaster/binding" |
||||||
|
"github.com/bilibili/Kratos/pkg/net/http/blademaster/render" |
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
"github.com/gogo/protobuf/types" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
_abortIndex int8 = math.MaxInt8 / 2 |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_openParen = []byte("(") |
||||||
|
_closeParen = []byte(")") |
||||||
|
) |
||||||
|
|
||||||
|
// Context is the most important part. It allows us to pass variables between
|
||||||
|
// middleware, manage the flow, validate the JSON of a request and render a
|
||||||
|
// JSON response for example.
|
||||||
|
type Context struct { |
||||||
|
context.Context |
||||||
|
|
||||||
|
Request *http.Request |
||||||
|
Writer http.ResponseWriter |
||||||
|
|
||||||
|
// flow control
|
||||||
|
index int8 |
||||||
|
handlers []HandlerFunc |
||||||
|
|
||||||
|
// Keys is a key/value pair exclusively for the context of each request.
|
||||||
|
Keys map[string]interface{} |
||||||
|
|
||||||
|
Error error |
||||||
|
|
||||||
|
method string |
||||||
|
engine *Engine |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/*********** FLOW CONTROL ***********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// Next should be used only inside middleware.
|
||||||
|
// It executes the pending handlers in the chain inside the calling handler.
|
||||||
|
// See example in godoc.
|
||||||
|
func (c *Context) Next() { |
||||||
|
c.index++ |
||||||
|
s := int8(len(c.handlers)) |
||||||
|
for ; c.index < s; c.index++ { |
||||||
|
// only check method on last handler, otherwise middlewares
|
||||||
|
// will never be effected if request method is not matched
|
||||||
|
if c.index == s-1 && c.method != c.Request.Method { |
||||||
|
code := http.StatusMethodNotAllowed |
||||||
|
c.Error = ecode.MethodNotAllowed |
||||||
|
http.Error(c.Writer, http.StatusText(code), code) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
c.handlers[c.index](c) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
|
||||||
|
// Let's say you have an authorization middleware that validates that the current request is authorized.
|
||||||
|
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
|
||||||
|
// for this request are not called.
|
||||||
|
func (c *Context) Abort() { |
||||||
|
c.index = _abortIndex |
||||||
|
} |
||||||
|
|
||||||
|
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
|
||||||
|
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
|
||||||
|
func (c *Context) AbortWithStatus(code int) { |
||||||
|
c.Status(code) |
||||||
|
c.Abort() |
||||||
|
} |
||||||
|
|
||||||
|
// IsAborted returns true if the current context was aborted.
|
||||||
|
func (c *Context) IsAborted() bool { |
||||||
|
return c.index >= _abortIndex |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/******** METADATA MANAGEMENT********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// Set is used to store a new key/value pair exclusively for this context.
|
||||||
|
// It also lazy initializes c.Keys if it was not used previously.
|
||||||
|
func (c *Context) Set(key string, value interface{}) { |
||||||
|
if c.Keys == nil { |
||||||
|
c.Keys = make(map[string]interface{}) |
||||||
|
} |
||||||
|
c.Keys[key] = value |
||||||
|
} |
||||||
|
|
||||||
|
// Get returns the value for the given key, ie: (value, true).
|
||||||
|
// If the value does not exists it returns (nil, false)
|
||||||
|
func (c *Context) Get(key string) (value interface{}, exists bool) { |
||||||
|
value, exists = c.Keys[key] |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
/************************************/ |
||||||
|
/******** RESPONSE RENDERING ********/ |
||||||
|
/************************************/ |
||||||
|
|
||||||
|
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
|
||||||
|
func bodyAllowedForStatus(status int) bool { |
||||||
|
switch { |
||||||
|
case status >= 100 && status <= 199: |
||||||
|
return false |
||||||
|
case status == 204: |
||||||
|
return false |
||||||
|
case status == 304: |
||||||
|
return false |
||||||
|
} |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
// Status sets the HTTP response code.
|
||||||
|
func (c *Context) Status(code int) { |
||||||
|
c.Writer.WriteHeader(code) |
||||||
|
} |
||||||
|
|
||||||
|
// Render http response with http code by a render instance.
|
||||||
|
func (c *Context) Render(code int, r render.Render) { |
||||||
|
r.WriteContentType(c.Writer) |
||||||
|
if code > 0 { |
||||||
|
c.Status(code) |
||||||
|
} |
||||||
|
|
||||||
|
if !bodyAllowedForStatus(code) { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
params := c.Request.Form |
||||||
|
|
||||||
|
cb := params.Get("callback") |
||||||
|
jsonp := cb != "" && params.Get("jsonp") == "jsonp" |
||||||
|
if jsonp { |
||||||
|
c.Writer.Write([]byte(cb)) |
||||||
|
c.Writer.Write(_openParen) |
||||||
|
} |
||||||
|
|
||||||
|
if err := r.Render(c.Writer); err != nil { |
||||||
|
c.Error = err |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if jsonp { |
||||||
|
if _, err := c.Writer.Write(_closeParen); err != nil { |
||||||
|
c.Error = errors.WithStack(err) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// JSON serializes the given struct as JSON into the response body.
|
||||||
|
// It also sets the Content-Type as "application/json".
|
||||||
|
func (c *Context) JSON(data interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.JSON{ |
||||||
|
Code: bcode.Code(), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// JSONMap serializes the given map as map JSON into the response body.
|
||||||
|
// It also sets the Content-Type as "application/json".
|
||||||
|
func (c *Context) JSONMap(data map[string]interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
data["code"] = bcode.Code() |
||||||
|
if _, ok := data["message"]; !ok { |
||||||
|
data["message"] = bcode.Message() |
||||||
|
} |
||||||
|
c.Render(code, render.MapJSON(data)) |
||||||
|
} |
||||||
|
|
||||||
|
// XML serializes the given struct as XML into the response body.
|
||||||
|
// It also sets the Content-Type as "application/xml".
|
||||||
|
func (c *Context) XML(data interface{}, err error) { |
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
// TODO app allow 5xx?
|
||||||
|
/* |
||||||
|
if bcode.Code() == -500 { |
||||||
|
code = http.StatusServiceUnavailable |
||||||
|
} |
||||||
|
*/ |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.XML{ |
||||||
|
Code: bcode.Code(), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// Protobuf serializes the given struct as PB into the response body.
|
||||||
|
// It also sets the ContentType as "application/x-protobuf".
|
||||||
|
func (c *Context) Protobuf(data proto.Message, err error) { |
||||||
|
var ( |
||||||
|
bytes []byte |
||||||
|
) |
||||||
|
|
||||||
|
code := http.StatusOK |
||||||
|
c.Error = err |
||||||
|
bcode := ecode.Cause(err) |
||||||
|
|
||||||
|
any := new(types.Any) |
||||||
|
if data != nil { |
||||||
|
if bytes, err = proto.Marshal(data); err != nil { |
||||||
|
c.Error = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data) |
||||||
|
any.Value = bytes |
||||||
|
} |
||||||
|
writeStatusCode(c.Writer, bcode.Code()) |
||||||
|
c.Render(code, render.PB{ |
||||||
|
Code: int64(bcode.Code()), |
||||||
|
Message: bcode.Message(), |
||||||
|
Data: any, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// Bytes writes some data into the body stream and updates the HTTP code.
|
||||||
|
func (c *Context) Bytes(code int, contentType string, data ...[]byte) { |
||||||
|
c.Render(code, render.Data{ |
||||||
|
ContentType: contentType, |
||||||
|
Data: data, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// String writes the given string into the response body.
|
||||||
|
func (c *Context) String(code int, format string, values ...interface{}) { |
||||||
|
c.Render(code, render.String{Format: format, Data: values}) |
||||||
|
} |
||||||
|
|
||||||
|
// Redirect returns a HTTP redirect to the specific location.
|
||||||
|
func (c *Context) Redirect(code int, location string) { |
||||||
|
c.Render(-1, render.Redirect{ |
||||||
|
Code: code, |
||||||
|
Location: location, |
||||||
|
Request: c.Request, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// BindWith bind req arg with parser.
|
||||||
|
func (c *Context) BindWith(obj interface{}, b binding.Binding) error { |
||||||
|
return c.mustBindWith(obj, b) |
||||||
|
} |
||||||
|
|
||||||
|
// Bind bind req arg with defult form binding.
|
||||||
|
func (c *Context) Bind(obj interface{}) error { |
||||||
|
return c.mustBindWith(obj, binding.Form) |
||||||
|
} |
||||||
|
|
||||||
|
// mustBindWith binds the passed struct pointer using the specified binding engine.
|
||||||
|
// It will abort the request with HTTP 400 if any error ocurrs.
|
||||||
|
// See the binding package.
|
||||||
|
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) { |
||||||
|
if err = b.Bind(c.Request, obj); err != nil { |
||||||
|
c.Error = ecode.RequestErr |
||||||
|
c.Render(http.StatusOK, render.JSON{ |
||||||
|
Code: ecode.RequestErr.Code(), |
||||||
|
Message: err.Error(), |
||||||
|
Data: nil, |
||||||
|
}) |
||||||
|
c.Abort() |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func writeStatusCode(w http.ResponseWriter, ecode int) { |
||||||
|
header := w.Header() |
||||||
|
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10)) |
||||||
|
} |
@ -0,0 +1,249 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/log" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// CORSConfig represents all available options for the middleware.
|
||||||
|
type CORSConfig struct { |
||||||
|
AllowAllOrigins bool |
||||||
|
|
||||||
|
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||||
|
// If the special "*" value is present in the list, all origins will be allowed.
|
||||||
|
// Default value is []
|
||||||
|
AllowOrigins []string |
||||||
|
|
||||||
|
// AllowOriginFunc is a custom function to validate the origin. It take the origin
|
||||||
|
// as argument and returns true if allowed or false otherwise. If this option is
|
||||||
|
// set, the content of AllowedOrigins is ignored.
|
||||||
|
AllowOriginFunc func(origin string) bool |
||||||
|
|
||||||
|
// AllowedMethods is a list of methods the client is allowed to use with
|
||||||
|
// cross-domain requests. Default value is simple methods (GET and POST)
|
||||||
|
AllowMethods []string |
||||||
|
|
||||||
|
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||||
|
// cross-domain requests.
|
||||||
|
AllowHeaders []string |
||||||
|
|
||||||
|
// AllowCredentials indicates whether the request can include user credentials like
|
||||||
|
// cookies, HTTP authentication or client side SSL certificates.
|
||||||
|
AllowCredentials bool |
||||||
|
|
||||||
|
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||||
|
// API specification
|
||||||
|
ExposeHeaders []string |
||||||
|
|
||||||
|
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||||
|
// can be cached
|
||||||
|
MaxAge time.Duration |
||||||
|
} |
||||||
|
|
||||||
|
type cors struct { |
||||||
|
allowAllOrigins bool |
||||||
|
allowCredentials bool |
||||||
|
allowOriginFunc func(string) bool |
||||||
|
allowOrigins []string |
||||||
|
normalHeaders http.Header |
||||||
|
preflightHeaders http.Header |
||||||
|
} |
||||||
|
|
||||||
|
type converter func(string) string |
||||||
|
|
||||||
|
// Validate is check configuration of user defined.
|
||||||
|
func (c *CORSConfig) Validate() error { |
||||||
|
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { |
||||||
|
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed") |
||||||
|
} |
||||||
|
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { |
||||||
|
return errors.New("conflict settings: all origins disabled") |
||||||
|
} |
||||||
|
for _, origin := range c.AllowOrigins { |
||||||
|
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") { |
||||||
|
return errors.New("bad origin: origins must either be '*' or include http:// or https://") |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// CORS returns the location middleware with default configuration.
|
||||||
|
func CORS(allowOriginHosts []string) HandlerFunc { |
||||||
|
config := &CORSConfig{ |
||||||
|
AllowMethods: []string{"GET", "POST"}, |
||||||
|
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"}, |
||||||
|
AllowCredentials: true, |
||||||
|
MaxAge: time.Duration(0), |
||||||
|
AllowOriginFunc: func(origin string) bool { |
||||||
|
for _, host := range allowOriginHosts { |
||||||
|
if strings.HasSuffix(strings.ToLower(origin), host) { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
return false |
||||||
|
}, |
||||||
|
} |
||||||
|
return newCORS(config) |
||||||
|
} |
||||||
|
|
||||||
|
// newCORS returns the location middleware with user-defined custom configuration.
|
||||||
|
func newCORS(config *CORSConfig) HandlerFunc { |
||||||
|
if err := config.Validate(); err != nil { |
||||||
|
panic(err.Error()) |
||||||
|
} |
||||||
|
cors := &cors{ |
||||||
|
allowOriginFunc: config.AllowOriginFunc, |
||||||
|
allowAllOrigins: config.AllowAllOrigins, |
||||||
|
allowCredentials: config.AllowCredentials, |
||||||
|
allowOrigins: normalize(config.AllowOrigins), |
||||||
|
normalHeaders: generateNormalHeaders(config), |
||||||
|
preflightHeaders: generatePreflightHeaders(config), |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *Context) { |
||||||
|
cors.applyCORS(c) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) applyCORS(c *Context) { |
||||||
|
origin := c.Request.Header.Get("Origin") |
||||||
|
if len(origin) == 0 { |
||||||
|
// request is not a CORS request
|
||||||
|
return |
||||||
|
} |
||||||
|
if !cors.validateOrigin(origin) { |
||||||
|
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin) |
||||||
|
c.AbortWithStatus(http.StatusForbidden) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if c.Request.Method == "OPTIONS" { |
||||||
|
cors.handlePreflight(c) |
||||||
|
defer c.AbortWithStatus(200) |
||||||
|
} else { |
||||||
|
cors.handleNormal(c) |
||||||
|
} |
||||||
|
|
||||||
|
if !cors.allowAllOrigins { |
||||||
|
header := c.Writer.Header() |
||||||
|
header.Set("Access-Control-Allow-Origin", origin) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) validateOrigin(origin string) bool { |
||||||
|
if cors.allowAllOrigins { |
||||||
|
return true |
||||||
|
} |
||||||
|
for _, value := range cors.allowOrigins { |
||||||
|
if value == origin { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
if cors.allowOriginFunc != nil { |
||||||
|
return cors.allowOriginFunc(origin) |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) handlePreflight(c *Context) { |
||||||
|
header := c.Writer.Header() |
||||||
|
for key, value := range cors.preflightHeaders { |
||||||
|
header[key] = value |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (cors *cors) handleNormal(c *Context) { |
||||||
|
header := c.Writer.Header() |
||||||
|
for key, value := range cors.normalHeaders { |
||||||
|
header[key] = value |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func generateNormalHeaders(c *CORSConfig) http.Header { |
||||||
|
headers := make(http.Header) |
||||||
|
if c.AllowCredentials { |
||||||
|
headers.Set("Access-Control-Allow-Credentials", "true") |
||||||
|
} |
||||||
|
|
||||||
|
// backport support for early browsers
|
||||||
|
if len(c.AllowMethods) > 0 { |
||||||
|
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||||
|
value := strings.Join(allowMethods, ",") |
||||||
|
headers.Set("Access-Control-Allow-Methods", value) |
||||||
|
} |
||||||
|
|
||||||
|
if len(c.ExposeHeaders) > 0 { |
||||||
|
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) |
||||||
|
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) |
||||||
|
} |
||||||
|
if c.AllowAllOrigins { |
||||||
|
headers.Set("Access-Control-Allow-Origin", "*") |
||||||
|
} else { |
||||||
|
headers.Set("Vary", "Origin") |
||||||
|
} |
||||||
|
return headers |
||||||
|
} |
||||||
|
|
||||||
|
func generatePreflightHeaders(c *CORSConfig) http.Header { |
||||||
|
headers := make(http.Header) |
||||||
|
if c.AllowCredentials { |
||||||
|
headers.Set("Access-Control-Allow-Credentials", "true") |
||||||
|
} |
||||||
|
if len(c.AllowMethods) > 0 { |
||||||
|
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) |
||||||
|
value := strings.Join(allowMethods, ",") |
||||||
|
headers.Set("Access-Control-Allow-Methods", value) |
||||||
|
} |
||||||
|
if len(c.AllowHeaders) > 0 { |
||||||
|
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) |
||||||
|
value := strings.Join(allowHeaders, ",") |
||||||
|
headers.Set("Access-Control-Allow-Headers", value) |
||||||
|
} |
||||||
|
if c.MaxAge > time.Duration(0) { |
||||||
|
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10) |
||||||
|
headers.Set("Access-Control-Max-Age", value) |
||||||
|
} |
||||||
|
if c.AllowAllOrigins { |
||||||
|
headers.Set("Access-Control-Allow-Origin", "*") |
||||||
|
} else { |
||||||
|
// Always set Vary headers
|
||||||
|
// see https://github.com/rs/cors/issues/10,
|
||||||
|
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||||
|
|
||||||
|
headers.Add("Vary", "Origin") |
||||||
|
headers.Add("Vary", "Access-Control-Request-Method") |
||||||
|
headers.Add("Vary", "Access-Control-Request-Headers") |
||||||
|
} |
||||||
|
return headers |
||||||
|
} |
||||||
|
|
||||||
|
func normalize(values []string) []string { |
||||||
|
if values == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
distinctMap := make(map[string]bool, len(values)) |
||||||
|
normalized := make([]string, 0, len(values)) |
||||||
|
for _, value := range values { |
||||||
|
value = strings.TrimSpace(value) |
||||||
|
value = strings.ToLower(value) |
||||||
|
if _, seen := distinctMap[value]; !seen { |
||||||
|
normalized = append(normalized, value) |
||||||
|
distinctMap[value] = true |
||||||
|
} |
||||||
|
} |
||||||
|
return normalized |
||||||
|
} |
||||||
|
|
||||||
|
func convert(s []string, c converter) []string { |
||||||
|
var out []string |
||||||
|
for _, i := range s { |
||||||
|
out = append(out, c(i)) |
||||||
|
} |
||||||
|
return out |
||||||
|
} |
@ -0,0 +1,69 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/url" |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
func matchHostSuffix(suffix string) func(*url.URL) bool { |
||||||
|
return func(uri *url.URL) bool { |
||||||
|
return strings.HasSuffix(strings.ToLower(uri.Host), suffix) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool { |
||||||
|
return func(uri *url.URL) bool { |
||||||
|
return pattern.MatchString(strings.ToLower(uri.String())) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// CSRF returns the csrf middleware to prevent invalid cross site request.
|
||||||
|
// Only referer is checked currently.
|
||||||
|
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc { |
||||||
|
validations := []func(*url.URL) bool{} |
||||||
|
|
||||||
|
addHostSuffix := func(suffix string) { |
||||||
|
validations = append(validations, matchHostSuffix(suffix)) |
||||||
|
} |
||||||
|
addPattern := func(pattern string) { |
||||||
|
validations = append(validations, matchPattern(regexp.MustCompile(pattern))) |
||||||
|
} |
||||||
|
|
||||||
|
for _, r := range allowHosts { |
||||||
|
addHostSuffix(r) |
||||||
|
} |
||||||
|
for _, p := range allowPattern { |
||||||
|
addPattern(p) |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *Context) { |
||||||
|
referer := c.Request.Header.Get("Referer") |
||||||
|
params := c.Request.Form |
||||||
|
cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "") |
||||||
|
if referer == "" { |
||||||
|
if !cross { |
||||||
|
return |
||||||
|
} |
||||||
|
log.V(5).Info("The request's Referer header is empty.") |
||||||
|
c.AbortWithStatus(403) |
||||||
|
return |
||||||
|
} |
||||||
|
illegal := true |
||||||
|
if uri, err := url.Parse(referer); err == nil && uri.Host != "" { |
||||||
|
for _, validate := range validations { |
||||||
|
if validate(uri) { |
||||||
|
illegal = false |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
if illegal { |
||||||
|
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer) |
||||||
|
c.AbortWithStatus(403) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,69 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"strconv" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/ecode" |
||||||
|
"github.com/bilibili/Kratos/pkg/log" |
||||||
|
"github.com/bilibili/Kratos/pkg/net/metadata" |
||||||
|
) |
||||||
|
|
||||||
|
// Logger is logger middleware
|
||||||
|
func Logger() HandlerFunc { |
||||||
|
const noUser = "no_user" |
||||||
|
return func(c *Context) { |
||||||
|
now := time.Now() |
||||||
|
ip := metadata.String(c, metadata.RemoteIP) |
||||||
|
req := c.Request |
||||||
|
path := req.URL.Path |
||||||
|
params := req.Form |
||||||
|
var quota float64 |
||||||
|
if deadline, ok := c.Context.Deadline(); ok { |
||||||
|
quota = time.Until(deadline).Seconds() |
||||||
|
} |
||||||
|
|
||||||
|
c.Next() |
||||||
|
|
||||||
|
err := c.Error |
||||||
|
cerr := ecode.Cause(err) |
||||||
|
dt := time.Since(now) |
||||||
|
caller := metadata.String(c, metadata.Caller) |
||||||
|
if caller == "" { |
||||||
|
caller = noUser |
||||||
|
} |
||||||
|
|
||||||
|
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10)) |
||||||
|
stats.Timing(caller, int64(dt/time.Millisecond), path[1:]) |
||||||
|
|
||||||
|
lf := log.Infov |
||||||
|
errmsg := "" |
||||||
|
isSlow := dt >= (time.Millisecond * 500) |
||||||
|
if err != nil { |
||||||
|
errmsg = err.Error() |
||||||
|
lf = log.Errorv |
||||||
|
if cerr.Code() > 0 { |
||||||
|
lf = log.Warnv |
||||||
|
} |
||||||
|
} else { |
||||||
|
if isSlow { |
||||||
|
lf = log.Warnv |
||||||
|
} |
||||||
|
} |
||||||
|
lf(c, |
||||||
|
log.KVString("method", req.Method), |
||||||
|
log.KVString("ip", ip), |
||||||
|
log.KVString("user", caller), |
||||||
|
log.KVString("path", path), |
||||||
|
log.KVString("params", params.Encode()), |
||||||
|
log.KVInt("ret", cerr.Code()), |
||||||
|
log.KVString("msg", cerr.Message()), |
||||||
|
log.KVString("stack", fmt.Sprintf("%+v", err)), |
||||||
|
log.KVString("err", errmsg), |
||||||
|
log.KVFloat64("timeout_quota", quota), |
||||||
|
log.KVFloat64("ts", dt.Seconds()), |
||||||
|
log.KVString("source", "http-access-log"), |
||||||
|
) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,46 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
"net/http" |
||||||
|
"net/http/pprof" |
||||||
|
"os" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/conf/dsn" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_perfOnce sync.Once |
||||||
|
_perfDSN string |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
v := os.Getenv("HTTP_PERF") |
||||||
|
if v == "" { |
||||||
|
v = "tcp://0.0.0.0:2333" |
||||||
|
} |
||||||
|
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.") |
||||||
|
} |
||||||
|
|
||||||
|
func startPerf() { |
||||||
|
_perfOnce.Do(func() { |
||||||
|
mux := http.NewServeMux() |
||||||
|
mux.HandleFunc("/debug/pprof/", pprof.Index) |
||||||
|
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) |
||||||
|
mux.HandleFunc("/debug/pprof/profile", pprof.Profile) |
||||||
|
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) |
||||||
|
|
||||||
|
go func() { |
||||||
|
d, err := dsn.Parse(_perfDSN) |
||||||
|
if err != nil { |
||||||
|
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err)) |
||||||
|
} |
||||||
|
if err := http.ListenAndServe(d.Host, mux); err != nil { |
||||||
|
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err)) |
||||||
|
} |
||||||
|
}() |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,12 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||||
|
) |
||||||
|
|
||||||
|
func monitor() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
h := promhttp.Handler() |
||||||
|
h.ServeHTTP(c.Writer, c.Request) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,32 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"net/http/httputil" |
||||||
|
"os" |
||||||
|
"runtime" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/log" |
||||||
|
) |
||||||
|
|
||||||
|
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||||
|
func Recovery() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
defer func() { |
||||||
|
var rawReq []byte |
||||||
|
if err := recover(); err != nil { |
||||||
|
const size = 64 << 10 |
||||||
|
buf := make([]byte, size) |
||||||
|
buf = buf[:runtime.Stack(buf, false)] |
||||||
|
if c.Request != nil { |
||||||
|
rawReq, _ = httputil.DumpRequest(c.Request, false) |
||||||
|
} |
||||||
|
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf) |
||||||
|
fmt.Fprintf(os.Stderr, pl) |
||||||
|
log.Error(pl) |
||||||
|
c.AbortWithStatus(500) |
||||||
|
} |
||||||
|
}() |
||||||
|
c.Next() |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,30 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// Data common bytes struct.
|
||||||
|
type Data struct { |
||||||
|
ContentType string |
||||||
|
Data [][]byte |
||||||
|
} |
||||||
|
|
||||||
|
// Render (Data) writes data with custom ContentType.
|
||||||
|
func (r Data) Render(w http.ResponseWriter) (err error) { |
||||||
|
r.WriteContentType(w) |
||||||
|
for _, d := range r.Data { |
||||||
|
if _, err = w.Write(d); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType writes data with custom ContentType.
|
||||||
|
func (r Data) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, []string{r.ContentType}) |
||||||
|
} |
@ -0,0 +1,58 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var jsonContentType = []string{"application/json; charset=utf-8"} |
||||||
|
|
||||||
|
// JSON common json struct.
|
||||||
|
type JSON struct { |
||||||
|
Code int `json:"code"` |
||||||
|
Message string `json:"message"` |
||||||
|
TTL int `json:"ttl"` |
||||||
|
Data interface{} `json:"data,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) { |
||||||
|
var jsonBytes []byte |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
if jsonBytes, err = json.Marshal(obj); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
if _, err = w.Write(jsonBytes); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// Render (JSON) writes data with json ContentType.
|
||||||
|
func (r JSON) Render(w http.ResponseWriter) error { |
||||||
|
// FIXME(zhoujiahui): the TTL field will be configurable in the future
|
||||||
|
if r.TTL <= 0 { |
||||||
|
r.TTL = 1 |
||||||
|
} |
||||||
|
return writeJSON(w, r) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write json ContentType.
|
||||||
|
func (r JSON) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
} |
||||||
|
|
||||||
|
// MapJSON common map json struct.
|
||||||
|
type MapJSON map[string]interface{} |
||||||
|
|
||||||
|
// Render (MapJSON) writes data with json ContentType.
|
||||||
|
func (m MapJSON) Render(w http.ResponseWriter) error { |
||||||
|
return writeJSON(w, m) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write json ContentType.
|
||||||
|
func (m MapJSON) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, jsonContentType) |
||||||
|
} |
@ -0,0 +1,38 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/gogo/protobuf/proto" |
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var pbContentType = []string{"application/x-protobuf"} |
||||||
|
|
||||||
|
// Render (PB) writes data with protobuf ContentType.
|
||||||
|
func (r PB) Render(w http.ResponseWriter) error { |
||||||
|
if r.TTL <= 0 { |
||||||
|
r.TTL = 1 |
||||||
|
} |
||||||
|
return writePB(w, r) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write protobuf ContentType.
|
||||||
|
func (r PB) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, pbContentType) |
||||||
|
} |
||||||
|
|
||||||
|
func writePB(w http.ResponseWriter, obj PB) (err error) { |
||||||
|
var pbBytes []byte |
||||||
|
writeContentType(w, pbContentType) |
||||||
|
|
||||||
|
if pbBytes, err = proto.Marshal(&obj); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
if _, err = w.Write(pbBytes); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,26 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// Redirect render for redirect to specified location.
|
||||||
|
type Redirect struct { |
||||||
|
Code int |
||||||
|
Request *http.Request |
||||||
|
Location string |
||||||
|
} |
||||||
|
|
||||||
|
// Render (Redirect) redirect to specified location.
|
||||||
|
func (r Redirect) Render(w http.ResponseWriter) error { |
||||||
|
if (r.Code < 300 || r.Code > 308) && r.Code != 201 { |
||||||
|
return errors.Errorf("Cannot redirect with status code %d", r.Code) |
||||||
|
} |
||||||
|
http.Redirect(w, r.Request, r.Location, r.Code) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType noneContentType.
|
||||||
|
func (r Redirect) WriteContentType(http.ResponseWriter) {} |
@ -0,0 +1,30 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
// Render http reponse render.
|
||||||
|
type Render interface { |
||||||
|
// Render render it to http response writer.
|
||||||
|
Render(http.ResponseWriter) error |
||||||
|
// WriteContentType write content-type to http response writer.
|
||||||
|
WriteContentType(w http.ResponseWriter) |
||||||
|
} |
||||||
|
|
||||||
|
var ( |
||||||
|
_ Render = JSON{} |
||||||
|
_ Render = MapJSON{} |
||||||
|
_ Render = XML{} |
||||||
|
_ Render = String{} |
||||||
|
_ Render = Redirect{} |
||||||
|
_ Render = Data{} |
||||||
|
_ Render = PB{} |
||||||
|
) |
||||||
|
|
||||||
|
func writeContentType(w http.ResponseWriter, value []string) { |
||||||
|
header := w.Header() |
||||||
|
if val := header["Content-Type"]; len(val) == 0 { |
||||||
|
header["Content-Type"] = value |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,89 @@ |
|||||||
|
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||||
|
// source: pb.proto
|
||||||
|
|
||||||
|
/* |
||||||
|
Package render is a generated protocol buffer package. |
||||||
|
|
||||||
|
It is generated from these files: |
||||||
|
pb.proto |
||||||
|
|
||||||
|
It has these top-level messages: |
||||||
|
PB |
||||||
|
*/ |
||||||
|
package render |
||||||
|
|
||||||
|
import proto "github.com/gogo/protobuf/proto" |
||||||
|
import fmt "fmt" |
||||||
|
import math "math" |
||||||
|
import google_protobuf "github.com/gogo/protobuf/types" |
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal |
||||||
|
var _ = fmt.Errorf |
||||||
|
var _ = math.Inf |
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the proto package it is being compiled against.
|
||||||
|
// A compilation error at this line likely means your copy of the
|
||||||
|
// proto package needs to be updated.
|
||||||
|
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
|
||||||
|
|
||||||
|
type PB struct { |
||||||
|
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` |
||||||
|
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"` |
||||||
|
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"` |
||||||
|
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) Reset() { *m = PB{} } |
||||||
|
func (m *PB) String() string { return proto.CompactTextString(m) } |
||||||
|
func (*PB) ProtoMessage() {} |
||||||
|
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} } |
||||||
|
|
||||||
|
func (m *PB) GetCode() int64 { |
||||||
|
if m != nil { |
||||||
|
return m.Code |
||||||
|
} |
||||||
|
return 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetMessage() string { |
||||||
|
if m != nil { |
||||||
|
return m.Message |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetTTL() uint64 { |
||||||
|
if m != nil { |
||||||
|
return m.TTL |
||||||
|
} |
||||||
|
return 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (m *PB) GetData() *google_protobuf.Any { |
||||||
|
if m != nil { |
||||||
|
return m.Data |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func init() { |
||||||
|
proto.RegisterType((*PB)(nil), "render.PB") |
||||||
|
} |
||||||
|
|
||||||
|
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) } |
||||||
|
|
||||||
|
var fileDescriptorPb = []byte{ |
||||||
|
// 154 bytes of a gzipped FileDescriptorProto
|
||||||
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b, |
||||||
|
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf, |
||||||
|
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94, |
||||||
|
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30, |
||||||
|
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89, |
||||||
|
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48, |
||||||
|
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92, |
||||||
|
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66, |
||||||
|
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00, |
||||||
|
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00, |
||||||
|
} |
@ -0,0 +1,14 @@ |
|||||||
|
// use under command to generate pb.pb.go |
||||||
|
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto |
||||||
|
syntax = "proto3"; |
||||||
|
package render; |
||||||
|
|
||||||
|
import "google/protobuf/any.proto"; |
||||||
|
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; |
||||||
|
|
||||||
|
message PB { |
||||||
|
int64 Code = 1; |
||||||
|
string Message = 2; |
||||||
|
uint64 TTL = 3; |
||||||
|
google.protobuf.Any Data = 4; |
||||||
|
} |
@ -0,0 +1,40 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
var plainContentType = []string{"text/plain; charset=utf-8"} |
||||||
|
|
||||||
|
// String common string struct.
|
||||||
|
type String struct { |
||||||
|
Format string |
||||||
|
Data []interface{} |
||||||
|
} |
||||||
|
|
||||||
|
// Render (String) writes data with custom ContentType.
|
||||||
|
func (r String) Render(w http.ResponseWriter) error { |
||||||
|
return writeString(w, r.Format, r.Data) |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType writes string with text/plain ContentType.
|
||||||
|
func (r String) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, plainContentType) |
||||||
|
} |
||||||
|
|
||||||
|
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) { |
||||||
|
writeContentType(w, plainContentType) |
||||||
|
if len(data) > 0 { |
||||||
|
_, err = fmt.Fprintf(w, format, data...) |
||||||
|
} else { |
||||||
|
_, err = io.WriteString(w, format) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,31 @@ |
|||||||
|
package render |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/xml" |
||||||
|
"net/http" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
// XML common xml struct.
|
||||||
|
type XML struct { |
||||||
|
Code int |
||||||
|
Message string |
||||||
|
Data interface{} |
||||||
|
} |
||||||
|
|
||||||
|
var xmlContentType = []string{"application/xml; charset=utf-8"} |
||||||
|
|
||||||
|
// Render (XML) writes data with xml ContentType.
|
||||||
|
func (r XML) Render(w http.ResponseWriter) (err error) { |
||||||
|
r.WriteContentType(w) |
||||||
|
if err = xml.NewEncoder(w).Encode(r.Data); err != nil { |
||||||
|
err = errors.WithStack(err) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// WriteContentType write xml ContentType.
|
||||||
|
func (r XML) WriteContentType(w http.ResponseWriter) { |
||||||
|
writeContentType(w, xmlContentType) |
||||||
|
} |
@ -0,0 +1,166 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"regexp" |
||||||
|
) |
||||||
|
|
||||||
|
// IRouter http router framework interface.
|
||||||
|
type IRouter interface { |
||||||
|
IRoutes |
||||||
|
Group(string, ...HandlerFunc) *RouterGroup |
||||||
|
} |
||||||
|
|
||||||
|
// IRoutes http router interface.
|
||||||
|
type IRoutes interface { |
||||||
|
UseFunc(...HandlerFunc) IRoutes |
||||||
|
Use(...Handler) IRoutes |
||||||
|
|
||||||
|
Handle(string, string, ...HandlerFunc) IRoutes |
||||||
|
HEAD(string, ...HandlerFunc) IRoutes |
||||||
|
GET(string, ...HandlerFunc) IRoutes |
||||||
|
POST(string, ...HandlerFunc) IRoutes |
||||||
|
PUT(string, ...HandlerFunc) IRoutes |
||||||
|
DELETE(string, ...HandlerFunc) IRoutes |
||||||
|
} |
||||||
|
|
||||||
|
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
|
||||||
|
// and an array of handlers (middleware).
|
||||||
|
type RouterGroup struct { |
||||||
|
Handlers []HandlerFunc |
||||||
|
basePath string |
||||||
|
engine *Engine |
||||||
|
root bool |
||||||
|
baseConfig *MethodConfig |
||||||
|
} |
||||||
|
|
||||||
|
var _ IRouter = &RouterGroup{} |
||||||
|
|
||||||
|
// Use adds middleware to the group, see example code in doc.
|
||||||
|
func (group *RouterGroup) Use(middleware ...Handler) IRoutes { |
||||||
|
for _, m := range middleware { |
||||||
|
group.Handlers = append(group.Handlers, m.ServeHTTP) |
||||||
|
} |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// UseFunc adds middleware to the group, see example code in doc.
|
||||||
|
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||||
|
group.Handlers = append(group.Handlers, middleware...) |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
|
||||||
|
// For example, all the routes that use a common middlware for authorization could be grouped.
|
||||||
|
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { |
||||||
|
return &RouterGroup{ |
||||||
|
Handlers: group.combineHandlers(handlers), |
||||||
|
basePath: group.calculateAbsolutePath(relativePath), |
||||||
|
engine: group.engine, |
||||||
|
root: false, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// SetMethodConfig is used to set config on specified method
|
||||||
|
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup { |
||||||
|
group.baseConfig = config |
||||||
|
return group |
||||||
|
} |
||||||
|
|
||||||
|
// BasePath router group base path.
|
||||||
|
func (group *RouterGroup) BasePath() string { |
||||||
|
return group.basePath |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
absolutePath := group.calculateAbsolutePath(relativePath) |
||||||
|
injections := group.injections(relativePath) |
||||||
|
handlers = group.combineHandlers(injections, handlers) |
||||||
|
group.engine.addRoute(httpMethod, absolutePath, handlers...) |
||||||
|
if group.baseConfig != nil { |
||||||
|
group.engine.SetMethodConfig(absolutePath, group.baseConfig) |
||||||
|
} |
||||||
|
return group.returnObj() |
||||||
|
} |
||||||
|
|
||||||
|
// Handle registers a new request handle and middleware with the given path and method.
|
||||||
|
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
|
||||||
|
// See the example code in doc.
|
||||||
|
//
|
||||||
|
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
|
||||||
|
// functions can be used.
|
||||||
|
//
|
||||||
|
// This function is intended for bulk loading and to allow the usage of less
|
||||||
|
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||||
|
// communication with a proxy).
|
||||||
|
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil { |
||||||
|
panic("http method " + httpMethod + " is not valid") |
||||||
|
} |
||||||
|
return group.handle(httpMethod, relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
|
||||||
|
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("HEAD", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// GET is a shortcut for router.Handle("GET", path, handle).
|
||||||
|
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("GET", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// POST is a shortcut for router.Handle("POST", path, handle).
|
||||||
|
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("POST", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// PUT is a shortcut for router.Handle("PUT", path, handle).
|
||||||
|
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("PUT", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
|
||||||
|
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { |
||||||
|
return group.handle("DELETE", relativePath, handlers...) |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc { |
||||||
|
finalSize := len(group.Handlers) |
||||||
|
for _, handlers := range handlerGroups { |
||||||
|
finalSize += len(handlers) |
||||||
|
} |
||||||
|
if finalSize >= int(_abortIndex) { |
||||||
|
panic("too many handlers") |
||||||
|
} |
||||||
|
mergedHandlers := make([]HandlerFunc, finalSize) |
||||||
|
copy(mergedHandlers, group.Handlers) |
||||||
|
position := len(group.Handlers) |
||||||
|
for _, handlers := range handlerGroups { |
||||||
|
copy(mergedHandlers[position:], handlers) |
||||||
|
position += len(handlers) |
||||||
|
} |
||||||
|
return mergedHandlers |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { |
||||||
|
return joinPaths(group.basePath, relativePath) |
||||||
|
} |
||||||
|
|
||||||
|
func (group *RouterGroup) returnObj() IRoutes { |
||||||
|
if group.root { |
||||||
|
return group.engine |
||||||
|
} |
||||||
|
return group |
||||||
|
} |
||||||
|
|
||||||
|
// injections is
|
||||||
|
func (group *RouterGroup) injections(relativePath string) []HandlerFunc { |
||||||
|
absPath := group.calculateAbsolutePath(relativePath) |
||||||
|
for _, injection := range group.engine.injections { |
||||||
|
if !injection.pattern.MatchString(absPath) { |
||||||
|
continue |
||||||
|
} |
||||||
|
return injection.handlers |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,445 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"flag" |
||||||
|
"fmt" |
||||||
|
"net" |
||||||
|
"net/http" |
||||||
|
"os" |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/bilibili/Kratos/pkg/conf/dsn" |
||||||
|
"github.com/bilibili/Kratos/pkg/log" |
||||||
|
"github.com/bilibili/Kratos/pkg/net/ip" |
||||||
|
"github.com/bilibili/Kratos/pkg/net/metadata" |
||||||
|
"github.com/bilibili/Kratos/pkg/stat" |
||||||
|
xtime "github.com/bilibili/Kratos/pkg/time" |
||||||
|
|
||||||
|
"github.com/pkg/errors" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
defaultMaxMemory = 32 << 20 // 32 MB
|
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
_ IRouter = &Engine{} |
||||||
|
stats = stat.HTTPServer |
||||||
|
|
||||||
|
_httpDSN string |
||||||
|
) |
||||||
|
|
||||||
|
func init() { |
||||||
|
addFlag(flag.CommandLine) |
||||||
|
} |
||||||
|
|
||||||
|
func addFlag(fs *flag.FlagSet) { |
||||||
|
v := os.Getenv("HTTP") |
||||||
|
if v == "" { |
||||||
|
v = "tcp://0.0.0.0:8000/?timeout=1s" |
||||||
|
} |
||||||
|
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.") |
||||||
|
} |
||||||
|
|
||||||
|
func parseDSN(rawdsn string) *ServerConfig { |
||||||
|
conf := new(ServerConfig) |
||||||
|
d, err := dsn.Parse(rawdsn) |
||||||
|
if err != nil { |
||||||
|
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||||
|
} |
||||||
|
if _, err = d.Bind(conf); err != nil { |
||||||
|
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn)) |
||||||
|
} |
||||||
|
return conf |
||||||
|
} |
||||||
|
|
||||||
|
// Handler responds to an HTTP request.
|
||||||
|
type Handler interface { |
||||||
|
ServeHTTP(c *Context) |
||||||
|
} |
||||||
|
|
||||||
|
// HandlerFunc http request handler function.
|
||||||
|
type HandlerFunc func(*Context) |
||||||
|
|
||||||
|
// ServeHTTP calls f(ctx).
|
||||||
|
func (f HandlerFunc) ServeHTTP(c *Context) { |
||||||
|
f(c) |
||||||
|
} |
||||||
|
|
||||||
|
// ServerConfig is the bm server config model
|
||||||
|
type ServerConfig struct { |
||||||
|
Network string `dsn:"network"` |
||||||
|
// FIXME: rename to Address
|
||||||
|
Addr string `dsn:"address"` |
||||||
|
Timeout xtime.Duration `dsn:"query.timeout"` |
||||||
|
ReadTimeout xtime.Duration `dsn:"query.readTimeout"` |
||||||
|
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"` |
||||||
|
} |
||||||
|
|
||||||
|
// MethodConfig is
|
||||||
|
type MethodConfig struct { |
||||||
|
Timeout xtime.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// Start listen and serve bm engine by given DSN.
|
||||||
|
func (engine *Engine) Start() error { |
||||||
|
conf := engine.conf |
||||||
|
l, err := net.Listen(conf.Network, conf.Addr) |
||||||
|
if err != nil { |
||||||
|
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Info("blademaster: start http listen addr: %s", conf.Addr) |
||||||
|
server := &http.Server{ |
||||||
|
ReadTimeout: time.Duration(conf.ReadTimeout), |
||||||
|
WriteTimeout: time.Duration(conf.WriteTimeout), |
||||||
|
} |
||||||
|
go func() { |
||||||
|
if err := engine.RunServer(server, l); err != nil { |
||||||
|
if errors.Cause(err) == http.ErrServerClosed { |
||||||
|
log.Info("blademaster: server closed") |
||||||
|
return |
||||||
|
} |
||||||
|
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l)) |
||||||
|
} |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
|
||||||
|
// Create an instance of Engine, by using New() or Default()
|
||||||
|
type Engine struct { |
||||||
|
RouterGroup |
||||||
|
|
||||||
|
lock sync.RWMutex |
||||||
|
conf *ServerConfig |
||||||
|
|
||||||
|
address string |
||||||
|
|
||||||
|
mux *http.ServeMux // http mux router
|
||||||
|
server atomic.Value // store *http.Server
|
||||||
|
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
|
||||||
|
|
||||||
|
pcLock sync.RWMutex |
||||||
|
methodConfigs map[string]*MethodConfig |
||||||
|
|
||||||
|
injections []injection |
||||||
|
} |
||||||
|
|
||||||
|
type injection struct { |
||||||
|
pattern *regexp.Regexp |
||||||
|
handlers []HandlerFunc |
||||||
|
} |
||||||
|
|
||||||
|
// New returns a new blank Engine instance without any middleware attached.
|
||||||
|
//
|
||||||
|
// Deprecated: please use NewServer.
|
||||||
|
func New() *Engine { |
||||||
|
engine := &Engine{ |
||||||
|
RouterGroup: RouterGroup{ |
||||||
|
Handlers: nil, |
||||||
|
basePath: "/", |
||||||
|
root: true, |
||||||
|
}, |
||||||
|
address: ip.InternalIP(), |
||||||
|
conf: &ServerConfig{ |
||||||
|
Timeout: xtime.Duration(time.Second), |
||||||
|
}, |
||||||
|
mux: http.NewServeMux(), |
||||||
|
metastore: make(map[string]map[string]interface{}), |
||||||
|
methodConfigs: make(map[string]*MethodConfig), |
||||||
|
injections: make([]injection, 0), |
||||||
|
} |
||||||
|
engine.RouterGroup.engine = engine |
||||||
|
// NOTE add prometheus monitor location
|
||||||
|
engine.addRoute("GET", "/metrics", monitor()) |
||||||
|
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||||
|
startPerf() |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// NewServer returns a new blank Engine instance without any middleware attached.
|
||||||
|
func NewServer(conf *ServerConfig) *Engine { |
||||||
|
if conf == nil { |
||||||
|
if !flag.Parsed() { |
||||||
|
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n") |
||||||
|
} |
||||||
|
conf = parseDSN(_httpDSN) |
||||||
|
} else { |
||||||
|
fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n") |
||||||
|
} |
||||||
|
|
||||||
|
engine := &Engine{ |
||||||
|
RouterGroup: RouterGroup{ |
||||||
|
Handlers: nil, |
||||||
|
basePath: "/", |
||||||
|
root: true, |
||||||
|
}, |
||||||
|
address: ip.InternalIP(), |
||||||
|
mux: http.NewServeMux(), |
||||||
|
metastore: make(map[string]map[string]interface{}), |
||||||
|
methodConfigs: make(map[string]*MethodConfig), |
||||||
|
} |
||||||
|
if err := engine.SetConfig(conf); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
engine.RouterGroup.engine = engine |
||||||
|
// NOTE add prometheus monitor location
|
||||||
|
engine.addRoute("GET", "/metrics", monitor()) |
||||||
|
engine.addRoute("GET", "/metadata", engine.metadata()) |
||||||
|
startPerf() |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// SetMethodConfig is used to set config on specified path
|
||||||
|
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) { |
||||||
|
engine.pcLock.Lock() |
||||||
|
engine.methodConfigs[path] = mc |
||||||
|
engine.pcLock.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||||
|
func DefaultServer(conf *ServerConfig) *Engine { |
||||||
|
engine := NewServer(conf) |
||||||
|
engine.Use(Recovery(), Trace(), Logger()) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||||
|
//
|
||||||
|
// Deprecated: please use DefaultServer.
|
||||||
|
func Default() *Engine { |
||||||
|
engine := New() |
||||||
|
engine.Use(Recovery(), Trace(), Logger()) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) { |
||||||
|
if path[0] != '/' { |
||||||
|
panic("blademaster: path must begin with '/'") |
||||||
|
} |
||||||
|
if method == "" { |
||||||
|
panic("blademaster: HTTP method can not be empty") |
||||||
|
} |
||||||
|
if len(handlers) == 0 { |
||||||
|
panic("blademaster: there must be at least one handler") |
||||||
|
} |
||||||
|
if _, ok := engine.metastore[path]; !ok { |
||||||
|
engine.metastore[path] = make(map[string]interface{}) |
||||||
|
} |
||||||
|
engine.metastore[path]["method"] = method |
||||||
|
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { |
||||||
|
c := &Context{ |
||||||
|
Context: nil, |
||||||
|
engine: engine, |
||||||
|
index: -1, |
||||||
|
handlers: nil, |
||||||
|
Keys: nil, |
||||||
|
method: "", |
||||||
|
Error: nil, |
||||||
|
} |
||||||
|
|
||||||
|
c.Request = req |
||||||
|
c.Writer = w |
||||||
|
c.handlers = handlers |
||||||
|
c.method = method |
||||||
|
|
||||||
|
engine.handleContext(c) |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
// SetConfig is used to set the engine configuration.
|
||||||
|
// Only the valid config will be loaded.
|
||||||
|
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) { |
||||||
|
if conf.Timeout <= 0 { |
||||||
|
return errors.New("blademaster: config timeout must greater than 0") |
||||||
|
} |
||||||
|
if conf.Network == "" { |
||||||
|
conf.Network = "tcp" |
||||||
|
} |
||||||
|
engine.lock.Lock() |
||||||
|
engine.conf = conf |
||||||
|
engine.lock.Unlock() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) methodConfig(path string) *MethodConfig { |
||||||
|
engine.pcLock.RLock() |
||||||
|
mc := engine.methodConfigs[path] |
||||||
|
engine.pcLock.RUnlock() |
||||||
|
return mc |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) handleContext(c *Context) { |
||||||
|
var cancel func() |
||||||
|
req := c.Request |
||||||
|
ctype := req.Header.Get("Content-Type") |
||||||
|
switch { |
||||||
|
case strings.Contains(ctype, "multipart/form-data"): |
||||||
|
req.ParseMultipartForm(defaultMaxMemory) |
||||||
|
default: |
||||||
|
req.ParseForm() |
||||||
|
} |
||||||
|
// get derived timeout from http request header,
|
||||||
|
// compare with the engine configured,
|
||||||
|
// and use the minimum one
|
||||||
|
engine.lock.RLock() |
||||||
|
tm := time.Duration(engine.conf.Timeout) |
||||||
|
engine.lock.RUnlock() |
||||||
|
// the method config is preferred
|
||||||
|
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil { |
||||||
|
tm = time.Duration(pc.Timeout) |
||||||
|
} |
||||||
|
if ctm := timeout(req); ctm > 0 && tm > ctm { |
||||||
|
tm = ctm |
||||||
|
} |
||||||
|
md := metadata.MD{ |
||||||
|
metadata.Color: color(req), |
||||||
|
metadata.RemoteIP: remoteIP(req), |
||||||
|
metadata.RemotePort: remotePort(req), |
||||||
|
metadata.Caller: caller(req), |
||||||
|
metadata.Mirror: mirror(req), |
||||||
|
} |
||||||
|
ctx := metadata.NewContext(context.Background(), md) |
||||||
|
if tm > 0 { |
||||||
|
c.Context, cancel = context.WithTimeout(ctx, tm) |
||||||
|
} else { |
||||||
|
c.Context, cancel = context.WithCancel(ctx) |
||||||
|
} |
||||||
|
defer cancel() |
||||||
|
c.Next() |
||||||
|
} |
||||||
|
|
||||||
|
// Router return a http.Handler for using http.ListenAndServe() directly.
|
||||||
|
func (engine *Engine) Router() http.Handler { |
||||||
|
return engine.mux |
||||||
|
} |
||||||
|
|
||||||
|
// Server is used to load stored http server.
|
||||||
|
func (engine *Engine) Server() *http.Server { |
||||||
|
s, ok := engine.server.Load().(*http.Server) |
||||||
|
if !ok { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return s |
||||||
|
} |
||||||
|
|
||||||
|
// Shutdown the http server without interrupting active connections.
|
||||||
|
func (engine *Engine) Shutdown(ctx context.Context) error { |
||||||
|
server := engine.Server() |
||||||
|
if server == nil { |
||||||
|
return errors.New("blademaster: no server") |
||||||
|
} |
||||||
|
return errors.WithStack(server.Shutdown(ctx)) |
||||||
|
} |
||||||
|
|
||||||
|
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
|
||||||
|
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||||
|
// For example, this is the right place for a logger or error management middleware.
|
||||||
|
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes { |
||||||
|
engine.RouterGroup.UseFunc(middleware...) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
|
||||||
|
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||||
|
// For example, this is the right place for a logger or error management middleware.
|
||||||
|
func (engine *Engine) Use(middleware ...Handler) IRoutes { |
||||||
|
engine.RouterGroup.Use(middleware...) |
||||||
|
return engine |
||||||
|
} |
||||||
|
|
||||||
|
// Ping is used to set the general HTTP ping handler.
|
||||||
|
func (engine *Engine) Ping(handler HandlerFunc) { |
||||||
|
engine.GET("/monitor/ping", handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Register is used to export metadata to discovery.
|
||||||
|
func (engine *Engine) Register(handler HandlerFunc) { |
||||||
|
engine.GET("/register", handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
|
||||||
|
// It is a shortcut for http.ListenAndServe(addr, router)
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) Run(addr ...string) (err error) { |
||||||
|
address := resolveAddress(addr) |
||||||
|
server := &http.Server{ |
||||||
|
Addr: address, |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.ListenAndServe(); err != nil { |
||||||
|
err = errors.Wrapf(err, "addrs: %v", addr) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
|
||||||
|
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { |
||||||
|
server := &http.Server{ |
||||||
|
Addr: addr, |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil { |
||||||
|
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
|
||||||
|
// through the specified unix socket (ie. a file).
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunUnix(file string) (err error) { |
||||||
|
os.Remove(file) |
||||||
|
listener, err := net.Listen("unix", file) |
||||||
|
if err != nil { |
||||||
|
err = errors.Wrapf(err, "unix: %s", file) |
||||||
|
return |
||||||
|
} |
||||||
|
defer listener.Close() |
||||||
|
server := &http.Server{ |
||||||
|
Handler: engine.mux, |
||||||
|
} |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.Serve(listener); err != nil { |
||||||
|
err = errors.Wrapf(err, "unix: %s", file) |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RunServer will serve and start listening HTTP requests by given server and listener.
|
||||||
|
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||||
|
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) { |
||||||
|
server.Handler = engine.mux |
||||||
|
engine.server.Store(server) |
||||||
|
if err = server.Serve(l); err != nil { |
||||||
|
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l) |
||||||
|
return |
||||||
|
} |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func (engine *Engine) metadata() HandlerFunc { |
||||||
|
return func(c *Context) { |
||||||
|
c.JSON(engine.metastore, nil) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Inject is
|
||||||
|
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) { |
||||||
|
engine.injections = append(engine.injections, injection{ |
||||||
|
pattern: regexp.MustCompile(pattern), |
||||||
|
handlers: handlers, |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package blademaster |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"path" |
||||||
|
) |
||||||
|
|
||||||
|
func lastChar(str string) uint8 { |
||||||
|
if str == "" { |
||||||
|
panic("The length of the string can't be 0") |
||||||
|
} |
||||||
|
return str[len(str)-1] |
||||||
|
} |
||||||
|
|
||||||
|
func joinPaths(absolutePath, relativePath string) string { |
||||||
|
if relativePath == "" { |
||||||
|
return absolutePath |
||||||
|
} |
||||||
|
|
||||||
|
finalPath := path.Join(absolutePath, relativePath) |
||||||
|
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' |
||||||
|
if appendSlash { |
||||||
|
return finalPath + "/" |
||||||
|
} |
||||||
|
return finalPath |
||||||
|
} |
||||||
|
|
||||||
|
func resolveAddress(addr []string) string { |
||||||
|
switch len(addr) { |
||||||
|
case 0: |
||||||
|
if port := os.Getenv("PORT"); port != "" { |
||||||
|
//debugPrint("Environment variable PORT=\"%s\"", port)
|
||||||
|
return ":" + port |
||||||
|
} |
||||||
|
//debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
|
||||||
|
return ":8080" |
||||||
|
case 1: |
||||||
|
return addr[0] |
||||||
|
default: |
||||||
|
panic("too much parameters") |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue