package http import ( "bytes" "context" "crypto/tls" "fmt" "io" "net/http" "time" "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/internal/httputil" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/registry" "github.com/go-kratos/kratos/v2/selector" "github.com/go-kratos/kratos/v2/selector/wrr" "github.com/go-kratos/kratos/v2/transport" ) func init() { if selector.GlobalSelector() == nil { selector.SetGlobalSelector(wrr.NewBuilder()) } } // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error // EncodeRequestFunc is request encode func. type EncodeRequestFunc func(ctx context.Context, contentType string, in interface{}) (body []byte, err error) // DecodeResponseFunc is response decode func. type DecodeResponseFunc func(ctx context.Context, res *http.Response, out interface{}) error // ClientOption is HTTP client option. type ClientOption func(*clientOptions) // Client is an HTTP transport client. type clientOptions struct { ctx context.Context tlsConf *tls.Config timeout time.Duration endpoint string userAgent string encoder EncodeRequestFunc decoder DecodeResponseFunc errorDecoder DecodeErrorFunc transport http.RoundTripper nodeFilters []selector.NodeFilter discovery registry.Discovery middleware []middleware.Middleware block bool } // WithTransport with client transport. func WithTransport(trans http.RoundTripper) ClientOption { return func(o *clientOptions) { o.transport = trans } } // WithTimeout with client request timeout. func WithTimeout(d time.Duration) ClientOption { return func(o *clientOptions) { o.timeout = d } } // WithUserAgent with client user agent. func WithUserAgent(ua string) ClientOption { return func(o *clientOptions) { o.userAgent = ua } } // WithMiddleware with client middleware. func WithMiddleware(m ...middleware.Middleware) ClientOption { return func(o *clientOptions) { o.middleware = m } } // WithEndpoint with client addr. func WithEndpoint(endpoint string) ClientOption { return func(o *clientOptions) { o.endpoint = endpoint } } // WithRequestEncoder with client request encoder. func WithRequestEncoder(encoder EncodeRequestFunc) ClientOption { return func(o *clientOptions) { o.encoder = encoder } } // WithResponseDecoder with client response decoder. func WithResponseDecoder(decoder DecodeResponseFunc) ClientOption { return func(o *clientOptions) { o.decoder = decoder } } // WithErrorDecoder with client error decoder. func WithErrorDecoder(errorDecoder DecodeErrorFunc) ClientOption { return func(o *clientOptions) { o.errorDecoder = errorDecoder } } // WithDiscovery with client discovery. func WithDiscovery(d registry.Discovery) ClientOption { return func(o *clientOptions) { o.discovery = d } } // WithNodeFilter with select filters func WithNodeFilter(filters ...selector.NodeFilter) ClientOption { return func(o *clientOptions) { o.nodeFilters = filters } } // WithBlock with client block. func WithBlock() ClientOption { return func(o *clientOptions) { o.block = true } } // WithTLSConfig with tls config. func WithTLSConfig(c *tls.Config) ClientOption { return func(o *clientOptions) { o.tlsConf = c } } // Client is an HTTP client. type Client struct { opts clientOptions target *Target r *resolver cc *http.Client insecure bool selector selector.Selector } // NewClient returns an HTTP client. func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { options := clientOptions{ ctx: ctx, timeout: 2000 * time.Millisecond, encoder: DefaultRequestEncoder, decoder: DefaultResponseDecoder, errorDecoder: DefaultErrorDecoder, transport: http.DefaultTransport, } for _, o := range opts { o(&options) } if options.tlsConf != nil { if tr, ok := options.transport.(*http.Transport); ok { tr.TLSClientConfig = options.tlsConf } } insecure := options.tlsConf == nil target, err := parseTarget(options.endpoint, insecure) if err != nil { return nil, err } selector := selector.GlobalSelector().Build() var r *resolver if options.discovery != nil { if target.Scheme == "discovery" { if r, err = newResolver(ctx, options.discovery, target, selector, options.block, insecure); err != nil { return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) } } else if _, _, err := host.ExtractHostPort(options.endpoint); err != nil { return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) } } return &Client{ opts: options, target: target, insecure: insecure, r: r, cc: &http.Client{ Timeout: options.timeout, Transport: options.transport, }, selector: selector, }, nil } // Invoke makes a rpc call procedure for remote service. func (client *Client) Invoke(ctx context.Context, method, path string, args interface{}, reply interface{}, opts ...CallOption) error { var ( contentType string body io.Reader ) c := defaultCallInfo(path) for _, o := range opts { if err := o.before(&c); err != nil { return err } } if args != nil { data, err := client.opts.encoder(ctx, c.contentType, args) if err != nil { return err } contentType = c.contentType body = bytes.NewReader(data) } url := fmt.Sprintf("%s://%s%s", client.target.Scheme, client.target.Authority, path) req, err := http.NewRequest(method, url, body) if err != nil { return err } if contentType != "" { req.Header.Set("Content-Type", c.contentType) } if client.opts.userAgent != "" { req.Header.Set("User-Agent", client.opts.userAgent) } ctx = transport.NewClientContext(ctx, &Transport{ endpoint: client.opts.endpoint, reqHeader: headerCarrier(req.Header), operation: c.operation, request: req, pathTemplate: c.pathTemplate, }) return client.invoke(ctx, req, args, reply, c, opts...) } func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo, opts ...CallOption) error { h := func(ctx context.Context, in interface{}) (interface{}, error) { res, err := client.do(req.WithContext(ctx)) if res != nil { cs := csAttempt{res: res} for _, o := range opts { o.after(&c, &cs) } } if err != nil { return nil, err } defer res.Body.Close() if err := client.opts.decoder(ctx, res, reply); err != nil { return nil, err } return reply, nil } var p selector.Peer ctx = selector.NewPeerContext(ctx, &p) if len(client.opts.middleware) > 0 { h = middleware.Chain(client.opts.middleware...)(h) } _, err := h(ctx, args) return err } // Do send an HTTP request and decodes the body of response into target. // returns an error (of type *Error) if the response status code is not 2xx. func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, error) { c := defaultCallInfo(req.URL.Path) for _, o := range opts { if err := o.before(&c); err != nil { return nil, err } } return client.do(req) } func (client *Client) do(req *http.Request) (*http.Response, error) { var done func(context.Context, selector.DoneInfo) if client.r != nil { var ( err error node selector.Node ) if node, done, err = client.selector.Select(req.Context(), selector.WithNodeFilter(client.opts.nodeFilters...)); err != nil { return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } if client.insecure { req.URL.Scheme = "http" } else { req.URL.Scheme = "https" } req.URL.Host = node.Address() req.Host = node.Address() } resp, err := client.cc.Do(req) if err == nil { err = client.opts.errorDecoder(req.Context(), resp) } if done != nil { done(req.Context(), selector.DoneInfo{Err: err}) } if err != nil { return nil, err } return resp, nil } // Close tears down the Transport and all underlying connections. func (client *Client) Close() error { if client.r != nil { return client.r.Close() } return nil } // DefaultRequestEncoder is an HTTP request encoder. func DefaultRequestEncoder(ctx context.Context, contentType string, in interface{}) ([]byte, error) { name := httputil.ContentSubtype(contentType) body, err := encoding.GetCodec(name).Marshal(in) if err != nil { return nil, err } return body, err } // DefaultResponseDecoder is an HTTP response decoder. func DefaultResponseDecoder(ctx context.Context, res *http.Response, v interface{}) error { defer res.Body.Close() data, err := io.ReadAll(res.Body) if err != nil { return err } return CodecForResponse(res).Unmarshal(data, v) } // DefaultErrorDecoder is an HTTP error decoder. func DefaultErrorDecoder(ctx context.Context, res *http.Response) error { if res.StatusCode >= 200 && res.StatusCode <= 299 { return nil } defer res.Body.Close() data, err := io.ReadAll(res.Body) if err == nil { e := new(errors.Error) if err = CodecForResponse(res).Unmarshal(data, e); err == nil { e.Code = int32(res.StatusCode) return e } } return errors.Newf(res.StatusCode, errors.UnknownReason, "").WithCause(err) } // CodecForResponse get encoding.Codec via http.Response func CodecForResponse(r *http.Response) encoding.Codec { codec := encoding.GetCodec(httputil.ContentSubtype(r.Header.Get("Content-Type"))) if codec != nil { return codec } return encoding.GetCodec("json") }