You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
kratos/transport/http/client.go

363 lines
9.4 KiB

package http
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"time"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/internal/host"
"github.com/go-kratos/kratos/v2/internal/httputil"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/registry"
"github.com/go-kratos/kratos/v2/selector"
"github.com/go-kratos/kratos/v2/selector/wrr"
"github.com/go-kratos/kratos/v2/transport"
)
// init installs a weighted-round-robin selector builder as the process-wide
// default selector, but only when no other package has registered one yet.
func init() {
	if selector.GlobalSelector() != nil {
		return
	}
	selector.SetGlobalSelector(wrr.NewBuilder())
}
// Pluggable hooks used by Client to encode requests and decode responses,
// plus the functional-option type accepted by NewClient.
type (
	// DecodeErrorFunc inspects a response and returns a non-nil error when the
	// response must be treated as a failed call.
	DecodeErrorFunc func(ctx context.Context, res *http.Response) error

	// EncodeRequestFunc marshals in into a request body for the given content type.
	EncodeRequestFunc func(ctx context.Context, contentType string, in interface{}) (body []byte, err error)

	// DecodeResponseFunc decodes the body of res into out.
	DecodeResponseFunc func(ctx context.Context, res *http.Response, out interface{}) error

	// ClientOption mutates the client configuration assembled by NewClient.
	ClientOption func(*clientOptions)
)
// clientOptions holds the configuration assembled by NewClient from its
// ClientOption arguments.
type clientOptions struct {
	// ctx is the context passed to NewClient.
	ctx context.Context
	// tlsConf, when non-nil, enables TLS; nil means the client is insecure (plain http).
	tlsConf *tls.Config
	// timeout is the http.Client timeout applied to every request.
	timeout time.Duration
	// endpoint is the raw target string parsed by NewClient.
	endpoint string
	// userAgent, when non-empty, is sent as the User-Agent header by Invoke.
	userAgent string
	// encoder marshals Invoke's args into the request body.
	encoder EncodeRequestFunc
	// decoder unmarshals a successful response body into the reply.
	decoder DecodeResponseFunc
	// errorDecoder turns a response into an error after each round trip.
	errorDecoder DecodeErrorFunc
	// transport is the RoundTripper used by the underlying http.Client.
	transport http.RoundTripper
	// nodeFilters are applied when the selector picks a node.
	nodeFilters []selector.NodeFilter
	// discovery, when non-nil, resolves "discovery" scheme endpoints.
	discovery registry.Discovery
	// middleware wraps every Invoke call.
	middleware []middleware.Middleware
	// block is forwarded to the discovery resolver constructor.
	block bool
}
// WithTransport makes the client send requests through trans instead of the
// default http.DefaultTransport.
func WithTransport(trans http.RoundTripper) ClientOption {
	return func(opts *clientOptions) { opts.transport = trans }
}
// WithTimeout sets the timeout of the underlying http.Client, which bounds
// each request the client sends.
func WithTimeout(d time.Duration) ClientOption {
	return func(opts *clientOptions) { opts.timeout = d }
}
// WithUserAgent sets the User-Agent header that Invoke attaches to outgoing
// requests. An empty value means no User-Agent header is set by the client.
func WithUserAgent(ua string) ClientOption {
	return func(opts *clientOptions) { opts.userAgent = ua }
}
// WithMiddleware sets the middleware chain that wraps every Invoke call.
// Applying the option again replaces, rather than extends, the previous chain.
func WithMiddleware(m ...middleware.Middleware) ClientOption {
	return func(opts *clientOptions) { opts.middleware = m }
}
// WithEndpoint sets the target address that NewClient parses into the
// client's request target.
func WithEndpoint(endpoint string) ClientOption {
	return func(opts *clientOptions) { opts.endpoint = endpoint }
}
// WithRequestEncoder replaces the function used to marshal request arguments
// into an HTTP request body.
func WithRequestEncoder(encoder EncodeRequestFunc) ClientOption {
	return func(opts *clientOptions) { opts.encoder = encoder }
}
// WithResponseDecoder replaces the function used to decode a successful
// response body into the caller's reply value.
func WithResponseDecoder(decoder DecodeResponseFunc) ClientOption {
	return func(opts *clientOptions) { opts.decoder = decoder }
}
// WithErrorDecoder replaces the function that converts a received response
// into an error; a non-nil result makes the call fail.
func WithErrorDecoder(errorDecoder DecodeErrorFunc) ClientOption {
	return func(opts *clientOptions) { opts.errorDecoder = errorDecoder }
}
// WithDiscovery sets the service registry used to resolve endpoints that use
// the "discovery" scheme.
func WithDiscovery(d registry.Discovery) ClientOption {
	return func(opts *clientOptions) { opts.discovery = d }
}
// WithNodeFilter sets the filters the selector applies when picking a node for
// a request. Applying the option again replaces the previous filters.
func WithNodeFilter(filters ...selector.NodeFilter) ClientOption {
	return func(opts *clientOptions) { opts.nodeFilters = filters }
}
// WithBlock sets the block flag handed to the discovery resolver when the
// client is created (presumably making NewClient wait for initial resolution;
// confirm in newResolver).
func WithBlock() ClientOption {
	return func(opts *clientOptions) { opts.block = true }
}
// WithTLSConfig enables TLS with the given configuration. A non-nil config is
// installed on the transport (when it is an *http.Transport) and makes the
// client use https for nodes picked through discovery.
func WithTLSConfig(c *tls.Config) ClientOption {
	return func(opts *clientOptions) { opts.tlsConf = c }
}
// Client is an HTTP client. Calls made through Invoke pass through the
// configured middleware chain. When the client was built with service
// discovery, each request is routed to a node chosen by its selector.
type Client struct {
	// opts is the configuration assembled by NewClient.
	opts clientOptions
	// target is the parsed endpoint; its Scheme and Authority form Invoke's request URLs.
	target *Target
	// r is the discovery resolver; nil when discovery is not in use.
	r *resolver
	// cc is the underlying HTTP client carrying the timeout and transport.
	cc *http.Client
	// insecure is true when no TLS config was supplied; it selects "http" for discovered nodes.
	insecure bool
	// selector picks the node for each request when r is non-nil.
	selector selector.Selector
}
// NewClient returns an HTTP client configured by opts.
//
// Defaults are:
//   - a 2s request timeout;
//   - DefaultRequestEncoder, DefaultResponseDecoder and DefaultErrorDecoder;
//   - http.DefaultTransport.
//
// When a TLS config is supplied and the transport is an *http.Transport, the
// config is installed on a clone of that transport. Shared transports, notably
// http.DefaultTransport, are therefore never mutated.
//
// With a discovery registry configured:
//   - an endpoint using the "discovery" scheme gets a resolver;
//   - any other endpoint must be a valid host:port.
func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
	options := clientOptions{
		ctx:          ctx,
		timeout:      2000 * time.Millisecond,
		encoder:      DefaultRequestEncoder,
		decoder:      DefaultResponseDecoder,
		errorDecoder: DefaultErrorDecoder,
		transport:    http.DefaultTransport,
	}
	for _, o := range opts {
		o(&options)
	}
	if options.tlsConf != nil {
		if tr, ok := options.transport.(*http.Transport); ok {
			// Clone so the TLS settings do not leak into http.DefaultTransport
			// (or any transport shared with other clients) process-wide.
			tr = tr.Clone()
			tr.TLSClientConfig = options.tlsConf
			options.transport = tr
		}
	}
	insecure := options.tlsConf == nil
	target, err := parseTarget(options.endpoint, insecure)
	if err != nil {
		return nil, err
	}
	// Named sel rather than selector so the selector package stays accessible.
	sel := selector.GlobalSelector().Build()
	var r *resolver
	if options.discovery != nil {
		if target.Scheme == "discovery" {
			if r, err = newResolver(ctx, options.discovery, target, sel, options.block, insecure); err != nil {
				return nil, fmt.Errorf("[http client] new resolver failed for endpoint %v: %w", options.endpoint, err)
			}
		} else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil {
			return nil, fmt.Errorf("[http client] invalid endpoint format: %v: %w", options.endpoint, err)
		}
	}
	return &Client{
		opts:     options,
		target:   target,
		insecure: insecure,
		r:        r,
		cc: &http.Client{
			Timeout:   options.timeout,
			Transport: options.transport,
		},
		selector: sel,
	}, nil
}
// Invoke performs an RPC-style call: it encodes args, sends a request with the
// given method to path on the client's target, and decodes the response into
// reply.
//
// Every CallOption's before hook runs first, and an error from any of them
// aborts the call. When args is nil the request has no body and no
// Content-Type header. The request is exposed to middleware through a client
// transport context before being handed to invoke.
func (client *Client) Invoke(ctx context.Context, method, path string, args interface{}, reply interface{}, opts ...CallOption) error {
	info := defaultCallInfo(path)
	for _, opt := range opts {
		if err := opt.before(&info); err != nil {
			return err
		}
	}

	// Must stay an io.Reader interface (not *bytes.Reader) so a nil body is a
	// true nil for http.NewRequest.
	var reqBody io.Reader
	if args != nil {
		payload, err := client.opts.encoder(ctx, info.contentType, args)
		if err != nil {
			return err
		}
		reqBody = bytes.NewReader(payload)
	}

	rawURL := fmt.Sprintf("%s://%s%s", client.target.Scheme, client.target.Authority, path)
	req, err := http.NewRequest(method, rawURL, reqBody)
	if err != nil {
		return err
	}
	if args != nil && info.contentType != "" {
		req.Header.Set("Content-Type", info.contentType)
	}
	if ua := client.opts.userAgent; ua != "" {
		req.Header.Set("User-Agent", ua)
	}

	tr := &Transport{
		endpoint:     client.opts.endpoint,
		reqHeader:    headerCarrier(req.Header),
		operation:    info.operation,
		request:      req,
		pathTemplate: info.pathTemplate,
	}
	ctx = transport.NewClientContext(ctx, tr)
	return client.invoke(ctx, req, args, reply, info, opts...)
}
// invoke runs a prepared request through the client's middleware chain.
//
// The innermost handler does three things:
//   - sends the request with the (possibly middleware-modified) context;
//   - lets each CallOption observe the response via its after hook;
//   - decodes the body into reply.
//
// A peer holder is attached to the context so middleware can inspect the
// selected peer.
func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo, opts ...CallOption) error {
	var peer selector.Peer
	ctx = selector.NewPeerContext(ctx, &peer)

	call := func(ctx context.Context, _ interface{}) (interface{}, error) {
		res, err := client.do(req.WithContext(ctx))
		if res != nil {
			attempt := csAttempt{res: res}
			for _, opt := range opts {
				opt.after(&c, &attempt)
			}
		}
		if err != nil {
			return nil, err
		}
		defer res.Body.Close()
		if err = client.opts.decoder(ctx, res, reply); err != nil {
			return nil, err
		}
		return reply, nil
	}
	if mws := client.opts.middleware; len(mws) > 0 {
		call = middleware.Chain(mws...)(call)
	}
	_, err := call(ctx, args)
	return err
}
// Do sends req through the client and returns the raw response.
//
// Unlike Invoke, Do neither encodes a request body nor decodes the response
// body, and it bypasses the client middleware chain. CallOptions run only
// their before hook, and an error from any of them aborts the call before
// anything is sent. When the client uses discovery, req's URL scheme/host and
// Host are rewritten to point at the selected node.
//
// If the configured error decoder reports an error (DefaultErrorDecoder does
// so for non-2xx statuses), that error is returned with a nil response. On
// success the caller owns resp.Body and must close it.
func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) {
	c := defaultCallInfo(req.URL.Path)
	for _, o := range opts {
		if err := o.before(&c); err != nil {
			return nil, err
		}
	}
	return client.do(req)
}
// do sends req with the underlying http.Client.
//
// When the client has a discovery resolver, the selector picks a node
// (honoring the configured node filters). The request URL and Host are then
// rewritten to that node, using "http" or "https" depending on whether a TLS
// config was supplied. After a successful round trip the response is passed
// to the error decoder, and the final outcome is reported to the selector via
// its done callback.
//
// On any error the response body (if any) is closed and a nil response is
// returned, because callers have no other way to release it.
func (client *Client) do(req *http.Request) (*http.Response, error) {
	var done func(context.Context, selector.DoneInfo)
	if client.r != nil {
		var (
			err  error
			node selector.Node
		)
		if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil {
			return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
		}
		if client.insecure {
			req.URL.Scheme = "http"
		} else {
			req.URL.Scheme = "https"
		}
		req.URL.Host = node.Address()
		req.Host = node.Address()
	}
	resp, err := client.cc.Do(req)
	if err == nil {
		err = client.opts.errorDecoder(req.Context(), resp)
	}
	if done != nil {
		done(req.Context(), selector.DoneInfo{Err: err})
	}
	if err != nil {
		if resp != nil {
			// A custom error decoder may leave the body open. The response is
			// discarded here, so close it to avoid leaking the connection.
			// Closing an already-closed response body is harmless.
			_ = resp.Body.Close()
		}
		return nil, err
	}
	return resp, nil
}
// Close releases the discovery resolver, if the client has one, and returns
// its error. Clients built without discovery have nothing to release and
// return nil. The underlying transport and its connections are left untouched.
func (client *Client) Close() error {
	if client.r == nil {
		return nil
	}
	return client.r.Close()
}
// DefaultRequestEncoder marshals in with the codec registered for the subtype
// of contentType (as extracted by httputil.ContentSubtype).
//
// It returns an error when no codec is registered for that subtype. This
// avoids calling Marshal on a nil codec, which would panic.
func DefaultRequestEncoder(ctx context.Context, contentType string, in interface{}) ([]byte, error) {
	name := httputil.ContentSubtype(contentType)
	codec := encoding.GetCodec(name)
	if codec == nil {
		return nil, fmt.Errorf("[http client] no codec registered for content type %q", contentType)
	}
	body, err := codec.Marshal(in)
	if err != nil {
		return nil, err
	}
	return body, nil
}
// DefaultResponseDecoder reads the entire response body, closes it, and
// unmarshals the bytes into v. The codec is chosen by CodecForResponse.
func DefaultResponseDecoder(ctx context.Context, res *http.Response, v interface{}) error {
	defer res.Body.Close()
	raw, err := io.ReadAll(res.Body)
	if err != nil {
		return err
	}
	codec := CodecForResponse(res)
	return codec.Unmarshal(raw, v)
}
// DefaultErrorDecoder treats any 2xx status as success and returns nil.
//
// For other statuses it reads and closes the body and tries to unmarshal it
// into an *errors.Error, whose Code is then set to the HTTP status. If the
// body cannot be read or unmarshalled, it returns an errors.UnknownReason
// error carrying the status code, with the failure attached as the cause.
func DefaultErrorDecoder(ctx context.Context, res *http.Response) error {
	if res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
		return nil
	}
	defer res.Body.Close()

	unknown := func(cause error) error {
		return errors.Newf(res.StatusCode, errors.UnknownReason, "").WithCause(cause)
	}
	data, err := io.ReadAll(res.Body)
	if err != nil {
		return unknown(err)
	}
	e := new(errors.Error)
	if err := CodecForResponse(res).Unmarshal(data, e); err != nil {
		return unknown(err)
	}
	e.Code = int32(res.StatusCode)
	return e
}
// CodecForResponse returns the encoding.Codec registered for the subtype of the
// response's Content-Type header. It falls back to the "json" codec when no
// codec matches.
func CodecForResponse(r *http.Response) encoding.Codec {
	subtype := httputil.ContentSubtype(r.Header.Get("Content-Type"))
	if codec := encoding.GetCodec(subtype); codec != nil {
		return codec
	}
	return encoding.GetCodec("json")
}