You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
3.3 KiB
137 lines
3.3 KiB
4 years ago
|
package http
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-kratos/kratos/v2/encoding"
|
||
|
"github.com/go-kratos/kratos/v2/errors"
|
||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||
|
"github.com/go-kratos/kratos/v2/transport"
|
||
|
)
|
||
|
|
||
|
// ClientOption is HTTP client option.
|
||
|
type ClientOption func(*clientOptions)
|
||
|
|
||
|
// WithTimeout with client request timeout.
|
||
|
func WithTimeout(d time.Duration) ClientOption {
|
||
|
return func(o *clientOptions) {
|
||
|
o.timeout = d
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithUserAgent with client user agent.
|
||
|
func WithUserAgent(ua string) ClientOption {
|
||
|
return func(o *clientOptions) {
|
||
|
o.userAgent = ua
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithTransport with client transport.
|
||
|
func WithTransport(trans http.RoundTripper) ClientOption {
|
||
|
return func(o *clientOptions) {
|
||
|
o.transport = trans
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithMiddleware with client middleware.
|
||
|
func WithMiddleware(m middleware.Middleware) ClientOption {
|
||
|
return func(o *clientOptions) {
|
||
|
o.middleware = m
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Client is a HTTP transport client.
|
||
|
type clientOptions struct {
|
||
|
ctx context.Context
|
||
|
timeout time.Duration
|
||
|
userAgent string
|
||
|
transport http.RoundTripper
|
||
|
middleware middleware.Middleware
|
||
|
}
|
||
|
|
||
|
// NewClient returns an HTTP client.
|
||
|
func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error) {
|
||
|
trans, err := NewTransport(ctx, opts...)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &http.Client{Transport: trans}, nil
|
||
|
}
|
||
|
|
||
|
// NewTransport creates an http.RoundTripper.
|
||
|
func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) {
|
||
|
options := &clientOptions{
|
||
|
ctx: ctx,
|
||
|
timeout: 500 * time.Millisecond,
|
||
|
transport: http.DefaultTransport,
|
||
|
}
|
||
|
for _, o := range opts {
|
||
|
o(options)
|
||
|
}
|
||
|
return &baseTransport{
|
||
|
middleware: options.middleware,
|
||
|
userAgent: options.userAgent,
|
||
|
timeout: options.timeout,
|
||
|
base: options.transport,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
type baseTransport struct {
|
||
|
userAgent string
|
||
|
timeout time.Duration
|
||
|
base http.RoundTripper
|
||
|
middleware middleware.Middleware
|
||
|
}
|
||
|
|
||
|
func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
if t.userAgent != "" && req.Header.Get("User-Agent") == "" {
|
||
|
req.Header.Set("User-Agent", t.userAgent)
|
||
|
}
|
||
|
ctx := transport.NewContext(req.Context(), transport.Transport{Kind: "HTTP"})
|
||
|
ctx = NewClientContext(ctx, ClientInfo{Request: req})
|
||
|
ctx, cancel := context.WithTimeout(ctx, t.timeout)
|
||
|
defer cancel()
|
||
|
|
||
|
h := func(ctx context.Context, in interface{}) (interface{}, error) {
|
||
|
return t.base.RoundTrip(req)
|
||
|
}
|
||
|
if t.middleware != nil {
|
||
|
h = t.middleware(h)
|
||
|
}
|
||
|
res, err := h(ctx, req)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return res.(*http.Response), nil
|
||
|
}
|
||
|
|
||
|
// Do send an HTTP request and decodes the body of response into target.
|
||
|
// returns an error (of type *Error) if the response status code is not 2xx.
|
||
|
func Do(client *http.Client, req *http.Request, target interface{}) error {
|
||
|
res, err := client.Do(req)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
data, err := ioutil.ReadAll(res.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer res.Body.Close()
|
||
|
subtype := contentSubtype(res.Header.Get("content-type"))
|
||
|
codec := encoding.GetCodec(subtype)
|
||
|
if codec == nil {
|
||
|
codec = encoding.GetCodec("json")
|
||
|
}
|
||
|
if res.StatusCode < 200 || res.StatusCode > 299 {
|
||
|
se := &errors.StatusError{}
|
||
|
if err := codec.Unmarshal(data, se); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return se
|
||
|
}
|
||
|
return codec.Unmarshal(data, target)
|
||
|
}
|