transport/http: clean http client (#981)

* clean http client
pull/985/head
Tony Chen 4 years ago committed by GitHub
parent 079f11fb50
commit bef6d8d818
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      examples/helloworld/client/main.go
  2. 1
      examples/registry/consul/client/main.go
  3. 3
      transport/http/balancer/random/random.go
  4. 145
      transport/http/client.go
  5. 31
      transport/http/resolver.go

@ -24,7 +24,6 @@ func callHTTP() {
recovery.Recovery(), recovery.Recovery(),
), ),
transhttp.WithEndpoint("127.0.0.1:8000"), transhttp.WithEndpoint("127.0.0.1:8000"),
transhttp.WithScheme("http"),
) )
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

@ -47,7 +47,6 @@ func callHTTP(cli *api.Client) {
transhttp.WithMiddleware( transhttp.WithMiddleware(
recovery.Recovery(), recovery.Recovery(),
), ),
transhttp.WithScheme("http"),
transhttp.WithEndpoint("discovery:///helloworld"), transhttp.WithEndpoint("discovery:///helloworld"),
transhttp.WithDiscovery(r), transhttp.WithDiscovery(r),
) )

@ -21,7 +21,8 @@ func New() *Balancer {
func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(context.Context, balancer.DoneInfo), err error) { func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(context.Context, balancer.DoneInfo), err error) {
if len(nodes) == 0 { if len(nodes) == 0 {
return nil, nil, fmt.Errorf("no instances avaiable") return nil, nil, fmt.Errorf("no instances avaiable")
} else if len(nodes) == 1 { }
if len(nodes) == 1 {
return nodes[0], func(context.Context, balancer.DoneInfo) {}, nil return nodes[0], func(context.Context, balancer.DoneInfo) {}, nil
} }
idx := rand.Intn(len(nodes)) idx := rand.Intn(len(nodes))

@ -7,7 +7,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding"
@ -20,23 +19,6 @@ import (
"github.com/go-kratos/kratos/v2/transport/http/balancer/random" "github.com/go-kratos/kratos/v2/transport/http/balancer/random"
) )
// Client is an HTTP client.
type Client struct {
cc *http.Client
r *resolver
b balancer.Balancer
scheme string
endpoint string
target Target
userAgent string
middleware middleware.Middleware
encoder EncodeRequestFunc
decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc
discovery registry.Discovery
}
// DecodeErrorFunc is decode error func. // DecodeErrorFunc is decode error func.
type DecodeErrorFunc func(ctx context.Context, res *http.Response) error type DecodeErrorFunc func(ctx context.Context, res *http.Response) error
@ -49,6 +31,21 @@ type DecodeResponseFunc func(ctx context.Context, res *http.Response, out interf
// ClientOption is HTTP client option. // ClientOption is HTTP client option.
type ClientOption func(*clientOptions) type ClientOption func(*clientOptions)
// Client is an HTTP transport client.
type clientOptions struct {
ctx context.Context
timeout time.Duration
endpoint string
userAgent string
encoder EncodeRequestFunc
decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc
transport http.RoundTripper
balancer balancer.Balancer
discovery registry.Discovery
middleware middleware.Middleware
}
// WithTransport with client transport. // WithTransport with client transport.
func WithTransport(trans http.RoundTripper) ClientOption { func WithTransport(trans http.RoundTripper) ClientOption {
return func(o *clientOptions) { return func(o *clientOptions) {
@ -77,13 +74,6 @@ func WithMiddleware(m ...middleware.Middleware) ClientOption {
} }
} }
// WithScheme with client schema.
func WithScheme(scheme string) ClientOption {
return func(o *clientOptions) {
o.scheme = scheme
}
}
// WithEndpoint with client addr. // WithEndpoint with client addr.
func WithEndpoint(endpoint string) ClientOption { func WithEndpoint(endpoint string) ClientOption {
return func(o *clientOptions) { return func(o *clientOptions) {
@ -128,27 +118,18 @@ func WithBalancer(b balancer.Balancer) ClientOption {
} }
} }
// Client is an HTTP transport client. // Client is an HTTP client.
type clientOptions struct { type Client struct {
ctx context.Context opts clientOptions
transport http.RoundTripper target *Target
middleware middleware.Middleware r *resolver
timeout time.Duration cc *http.Client
scheme string
endpoint string
userAgent string
encoder EncodeRequestFunc
decoder DecodeResponseFunc
errorDecoder DecodeErrorFunc
discovery registry.Discovery
balancer balancer.Balancer
} }
// NewClient returns an HTTP client. // NewClient returns an HTTP client.
func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
options := &clientOptions{ options := clientOptions{
ctx: ctx, ctx: ctx,
scheme: "http",
timeout: 500 * time.Millisecond, timeout: 500 * time.Millisecond,
encoder: DefaultRequestEncoder, encoder: DefaultRequestEncoder,
decoder: DefaultResponseDecoder, decoder: DefaultResponseDecoder,
@ -157,49 +138,30 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) {
balancer: random.New(), balancer: random.New(),
} }
for _, o := range opts { for _, o := range opts {
o(options) o(&options)
} }
target := Target{ target, err := parseTarget(options.endpoint)
Scheme: options.scheme,
Endpoint: options.endpoint,
}
var r *resolver
if options.endpoint != "" && options.discovery != nil {
u, err := url.Parse(options.endpoint)
if err != nil { if err != nil {
u, err = url.Parse("http://" + options.endpoint) return nil, err
if err != nil {
return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint)
}
}
if u.Scheme == "discovery" && len(u.Path) > 1 {
target = Target{
Scheme: u.Scheme,
Authority: u.Host,
Endpoint: u.Path[1:],
} }
r, err = newResolver(ctx, options.scheme, options.discovery, target) var r *resolver
if err != nil { if target.Endpoint != "" && options.discovery != nil {
if target.Scheme == "discovery" {
if r, err = newResolver(ctx, options.discovery, target); err != nil {
return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint)
} }
} else { } else {
return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint)
} }
} }
return &Client{ return &Client{
cc: &http.Client{Timeout: options.timeout, Transport: options.transport}, opts: options,
r: r,
encoder: options.encoder,
decoder: options.decoder,
errorDecoder: options.errorDecoder,
middleware: options.middleware,
userAgent: options.userAgent,
target: target, target: target,
scheme: options.scheme, r: r,
endpoint: options.endpoint, cc: &http.Client{
discovery: options.discovery, Timeout: options.timeout,
b: options.balancer, Transport: options.transport,
},
}, nil }, nil
} }
@ -220,13 +182,13 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{},
body []byte body []byte
err error err error
) )
contentType, body, err = client.encoder(ctx, args) contentType, body, err = client.opts.encoder(ctx, args)
if err != nil { if err != nil {
return err return err
} }
reqBody = bytes.NewReader(body) reqBody = bytes.NewReader(body)
} }
url := fmt.Sprintf("%s://%s%s", client.scheme, client.target.Endpoint, path) url := fmt.Sprintf("%s://%s%s", client.target.Scheme, client.target.Authority, path)
req, err := http.NewRequest(c.method, url, reqBody) req, err := http.NewRequest(c.method, url, reqBody)
if err != nil { if err != nil {
return err return err
@ -234,10 +196,10 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{},
if contentType != "" { if contentType != "" {
req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Type", contentType)
} }
if client.userAgent != "" { if client.opts.userAgent != "" {
req.Header.Set("User-Agent", client.userAgent) req.Header.Set("User-Agent", client.opts.userAgent)
} }
ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP, Endpoint: client.endpoint}) ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP, Endpoint: client.opts.endpoint})
ctx = NewClientContext(ctx, ClientInfo{PathPattern: c.pathPattern, Request: req}) ctx = NewClientContext(ctx, ClientInfo{PathPattern: c.pathPattern, Request: req})
return client.invoke(ctx, req, args, reply, c) return client.invoke(ctx, req, args, reply, c)
} }
@ -246,21 +208,20 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf
h := func(ctx context.Context, in interface{}) (interface{}, error) { h := func(ctx context.Context, in interface{}) (interface{}, error) {
var done func(context.Context, balancer.DoneInfo) var done func(context.Context, balancer.DoneInfo)
if client.r != nil { if client.r != nil {
nodes := client.r.fetch(ctx) var (
if len(nodes) == 0 { err error
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", "fetch error") node *registry.ServiceInstance
} nodes = client.r.fetch(ctx)
var node *registry.ServiceInstance )
var err error if node, done, err = client.opts.balancer.Pick(ctx, c.pathPattern, nodes); err != nil {
node, done, err = client.b.Pick(ctx, c.pathPattern, nodes)
if err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
} }
req = req.Clone(ctx) scheme, addr, err := parseEndpoint(node.Endpoints)
addr, err := parseEndpoint(client.scheme, node.Endpoints)
if err != nil { if err != nil {
return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error())
} }
req = req.Clone(ctx)
req.URL.Scheme = scheme
req.URL.Host = addr req.URL.Host = addr
} }
res, err := client.do(ctx, req, c) res, err := client.do(ctx, req, c)
@ -271,13 +232,13 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if err := client.decoder(ctx, res, reply); err != nil { if err := client.opts.decoder(ctx, res, reply); err != nil {
return nil, err return nil, err
} }
return reply, nil return reply, nil
} }
if client.middleware != nil { if client.opts.middleware != nil {
h = client.middleware(h) h = client.opts.middleware(h)
} }
_, err := h(ctx, args) _, err := h(ctx, args)
return err return err
@ -300,7 +261,7 @@ func (client *Client) do(ctx context.Context, req *http.Request, c callInfo) (*h
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := client.errorDecoder(ctx, resp); err != nil { if err := client.opts.errorDecoder(ctx, resp); err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil

@ -16,16 +16,30 @@ type Target struct {
Endpoint string Endpoint string
} }
func parseTarget(endpoint string) (*Target, error) {
u, err := url.Parse(endpoint)
if err != nil {
if u, err = url.Parse("http://" + endpoint); err != nil {
return nil, err
}
}
target := &Target{Scheme: u.Scheme, Authority: u.Host}
if len(u.Path) > 1 {
target.Endpoint = u.Path[1:]
}
return target, nil
}
type resolver struct { type resolver struct {
lock sync.RWMutex lock sync.RWMutex
nodes []*registry.ServiceInstance nodes []*registry.ServiceInstance
target Target target *Target
watcher registry.Watcher watcher registry.Watcher
logger *log.Helper logger *log.Helper
} }
func newResolver(ctx context.Context, scheme string, discovery registry.Discovery, target Target) (*resolver, error) { func newResolver(ctx context.Context, discovery registry.Discovery, target *Target) (*resolver, error) {
watcher, err := discovery.Watch(ctx, target.Endpoint) watcher, err := discovery.Watch(ctx, target.Endpoint)
if err != nil { if err != nil {
return nil, err return nil, err
@ -44,7 +58,7 @@ func newResolver(ctx context.Context, scheme string, discovery registry.Discover
} }
var nodes []*registry.ServiceInstance var nodes []*registry.ServiceInstance
for _, in := range services { for _, in := range services {
endpoint, err := parseEndpoint(scheme, in.Endpoints) _, endpoint, err := parseEndpoint(in.Endpoints)
if err != nil { if err != nil {
r.logger.Errorf("Failed to parse discovery endpoint: %v error %v", in.Endpoints, err) r.logger.Errorf("Failed to parse discovery endpoint: %v error %v", in.Endpoints, err)
continue continue
@ -68,19 +82,18 @@ func (r *resolver) fetch(ctx context.Context) []*registry.ServiceInstance {
r.lock.RLock() r.lock.RLock()
nodes := r.nodes nodes := r.nodes
r.lock.RUnlock() r.lock.RUnlock()
return nodes return nodes
} }
func parseEndpoint(schema string, endpoints []string) (string, error) { func parseEndpoint(endpoints []string) (string, string, error) {
for _, e := range endpoints { for _, e := range endpoints {
u, err := url.Parse(e) u, err := url.Parse(e)
if err != nil { if err != nil {
return "", err return "", "", err
} }
if u.Scheme == schema { if u.Scheme == "http" {
return u.Host, nil return u.Scheme, u.Host, nil
} }
} }
return "", nil return "", "", nil
} }

Loading…
Cancel
Save