diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index 7fd9cc00f..e49987c42 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -24,7 +24,6 @@ func callHTTP() { recovery.Recovery(), ), transhttp.WithEndpoint("127.0.0.1:8000"), - transhttp.WithScheme("http"), ) if err != nil { log.Fatal(err) diff --git a/examples/registry/consul/client/main.go b/examples/registry/consul/client/main.go index dfb7d5719..0bff724a0 100644 --- a/examples/registry/consul/client/main.go +++ b/examples/registry/consul/client/main.go @@ -47,7 +47,6 @@ func callHTTP(cli *api.Client) { transhttp.WithMiddleware( recovery.Recovery(), ), - transhttp.WithScheme("http"), transhttp.WithEndpoint("discovery:///helloworld"), transhttp.WithDiscovery(r), ) diff --git a/transport/http/balancer/random/random.go b/transport/http/balancer/random/random.go index f56e66dfd..9ac93c8d9 100644 --- a/transport/http/balancer/random/random.go +++ b/transport/http/balancer/random/random.go @@ -21,7 +21,8 @@ func New() *Balancer { func (b *Balancer) Pick(ctx context.Context, pathPattern string, nodes []*registry.ServiceInstance) (node *registry.ServiceInstance, done func(context.Context, balancer.DoneInfo), err error) { if len(nodes) == 0 { return nil, nil, fmt.Errorf("no instances avaiable") - } else if len(nodes) == 1 { + } + if len(nodes) == 1 { return nodes[0], func(context.Context, balancer.DoneInfo) {}, nil } idx := rand.Intn(len(nodes)) diff --git a/transport/http/client.go b/transport/http/client.go index 65351b5e2..c7bc91b5b 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -7,7 +7,6 @@ import ( "io" "io/ioutil" "net/http" - "net/url" "time" "github.com/go-kratos/kratos/v2/encoding" @@ -20,23 +19,6 @@ import ( "github.com/go-kratos/kratos/v2/transport/http/balancer/random" ) -// Client is an HTTP client. -type Client struct { - cc *http.Client - r *resolver - b balancer.Balancer - - scheme string - endpoint string - target Target - userAgent string - middleware middleware.Middleware - encoder EncodeRequestFunc - decoder DecodeResponseFunc - errorDecoder DecodeErrorFunc - discovery registry.Discovery -} - // DecodeErrorFunc is decode error func. type DecodeErrorFunc func(ctx context.Context, res *http.Response) error @@ -49,6 +31,21 @@ type DecodeResponseFunc func(ctx context.Context, res *http.Response, out interf // ClientOption is HTTP client option. type ClientOption func(*clientOptions) +// Client is an HTTP transport client. +type clientOptions struct { + ctx context.Context + timeout time.Duration + endpoint string + userAgent string + encoder EncodeRequestFunc + decoder DecodeResponseFunc + errorDecoder DecodeErrorFunc + transport http.RoundTripper + balancer balancer.Balancer + discovery registry.Discovery + middleware middleware.Middleware +} + // WithTransport with client transport. func WithTransport(trans http.RoundTripper) ClientOption { return func(o *clientOptions) { @@ -77,13 +74,6 @@ func WithMiddleware(m ...middleware.Middleware) ClientOption { } } -// WithScheme with client schema. -func WithScheme(scheme string) ClientOption { - return func(o *clientOptions) { - o.scheme = scheme - } -} - // WithEndpoint with client addr. func WithEndpoint(endpoint string) ClientOption { return func(o *clientOptions) { @@ -128,27 +118,18 @@ func WithBalancer(b balancer.Balancer) ClientOption { } } -// Client is an HTTP transport client. -type clientOptions struct { - ctx context.Context - transport http.RoundTripper - middleware middleware.Middleware - timeout time.Duration - scheme string - endpoint string - userAgent string - encoder EncodeRequestFunc - decoder DecodeResponseFunc - errorDecoder DecodeErrorFunc - discovery registry.Discovery - balancer balancer.Balancer +// Client is an HTTP client. +type Client struct { + opts clientOptions + target *Target + r *resolver + cc *http.Client } // NewClient returns an HTTP client. func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { - options := &clientOptions{ + options := clientOptions{ ctx: ctx, - scheme: "http", timeout: 500 * time.Millisecond, encoder: DefaultRequestEncoder, decoder: DefaultResponseDecoder, @@ -157,49 +138,30 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { balancer: random.New(), } for _, o := range opts { - o(options) + o(&options) } - target := Target{ - Scheme: options.scheme, - Endpoint: options.endpoint, + target, err := parseTarget(options.endpoint) + if err != nil { + return nil, err } var r *resolver - if options.endpoint != "" && options.discovery != nil { - u, err := url.Parse(options.endpoint) - if err != nil { - u, err = url.Parse("http://" + options.endpoint) - if err != nil { - return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) - } - } - if u.Scheme == "discovery" && len(u.Path) > 1 { - target = Target{ - Scheme: u.Scheme, - Authority: u.Host, - Endpoint: u.Path[1:], - } - r, err = newResolver(ctx, options.scheme, options.discovery, target) - if err != nil { + if target.Endpoint != "" && options.discovery != nil { + if target.Scheme == "discovery" { + if r, err = newResolver(ctx, options.discovery, target); err != nil { return nil, fmt.Errorf("[http client] new resolver failed!err: %v", options.endpoint) } } else { return nil, fmt.Errorf("[http client] invalid endpoint format: %v", options.endpoint) } } - return &Client{ - cc: &http.Client{Timeout: options.timeout, Transport: options.transport}, - r: r, - encoder: options.encoder, - decoder: options.decoder, - errorDecoder: options.errorDecoder, - middleware: options.middleware, - userAgent: options.userAgent, - target: target, - scheme: options.scheme, - endpoint: options.endpoint, - discovery: options.discovery, - b: options.balancer, + opts: options, + target: target, + r: r, + cc: &http.Client{ + Timeout: options.timeout, + Transport: options.transport, + }, }, nil } @@ -220,13 +182,13 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, body []byte err error ) - contentType, body, err = client.encoder(ctx, args) + contentType, body, err = client.opts.encoder(ctx, args) if err != nil { return err } reqBody = bytes.NewReader(body) } - url := fmt.Sprintf("%s://%s%s", client.scheme, client.target.Endpoint, path) + url := fmt.Sprintf("%s://%s%s", client.target.Scheme, client.target.Authority, path) req, err := http.NewRequest(c.method, url, reqBody) if err != nil { return err @@ -234,10 +196,10 @@ func (client *Client) Invoke(ctx context.Context, path string, args interface{}, if contentType != "" { req.Header.Set("Content-Type", contentType) } - if client.userAgent != "" { - req.Header.Set("User-Agent", client.userAgent) + if client.opts.userAgent != "" { + req.Header.Set("User-Agent", client.opts.userAgent) } - ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP, Endpoint: client.endpoint}) + ctx = transport.NewContext(ctx, transport.Transport{Kind: transport.KindHTTP, Endpoint: client.opts.endpoint}) ctx = NewClientContext(ctx, ClientInfo{PathPattern: c.pathPattern, Request: req}) return client.invoke(ctx, req, args, reply, c) } @@ -246,21 +208,20 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf h := func(ctx context.Context, in interface{}) (interface{}, error) { var done func(context.Context, balancer.DoneInfo) if client.r != nil { - nodes := client.r.fetch(ctx) - if len(nodes) == 0 { - return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", "fetch error") - } - var node *registry.ServiceInstance - var err error - node, done, err = client.b.Pick(ctx, c.pathPattern, nodes) - if err != nil { + var ( + err error + node *registry.ServiceInstance + nodes = client.r.fetch(ctx) + ) + if node, done, err = client.opts.balancer.Pick(ctx, c.pathPattern, nodes); err != nil { return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } - req = req.Clone(ctx) - addr, err := parseEndpoint(client.scheme, node.Endpoints) + scheme, addr, err := parseEndpoint(node.Endpoints) if err != nil { return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) } + req = req.Clone(ctx) + req.URL.Scheme = scheme req.URL.Host = addr } res, err := client.do(ctx, req, c) @@ -271,13 +232,13 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf return nil, err } defer res.Body.Close() - if err := client.decoder(ctx, res, reply); err != nil { + if err := client.opts.decoder(ctx, res, reply); err != nil { return nil, err } return reply, nil } - if client.middleware != nil { - h = client.middleware(h) + if client.opts.middleware != nil { + h = client.opts.middleware(h) } _, err := h(ctx, args) return err @@ -300,7 +261,7 @@ func (client *Client) do(ctx context.Context, req *http.Request, c callInfo) (*h if err != nil { return nil, err } - if err := client.errorDecoder(ctx, resp); err != nil { + if err := client.opts.errorDecoder(ctx, resp); err != nil { return nil, err } return resp, nil diff --git a/transport/http/resolver.go b/transport/http/resolver.go index 2f605da49..c120f6f7e 100644 --- a/transport/http/resolver.go +++ b/transport/http/resolver.go @@ -16,16 +16,30 @@ type Target struct { Endpoint string } +func parseTarget(endpoint string) (*Target, error) { + u, err := url.Parse(endpoint) + if err != nil { + if u, err = url.Parse("http://" + endpoint); err != nil { + return nil, err + } + } + target := &Target{Scheme: u.Scheme, Authority: u.Host} + if len(u.Path) > 1 { + target.Endpoint = u.Path[1:] + } + return target, nil +} + type resolver struct { lock sync.RWMutex nodes []*registry.ServiceInstance - target Target + target *Target watcher registry.Watcher logger *log.Helper } -func newResolver(ctx context.Context, scheme string, discovery registry.Discovery, target Target) (*resolver, error) { +func newResolver(ctx context.Context, discovery registry.Discovery, target *Target) (*resolver, error) { watcher, err := discovery.Watch(ctx, target.Endpoint) if err != nil { return nil, err @@ -44,7 +58,7 @@ func newResolver(ctx context.Context, scheme string, discovery registry.Discover } var nodes []*registry.ServiceInstance for _, in := range services { - endpoint, err := parseEndpoint(scheme, in.Endpoints) + _, endpoint, err := parseEndpoint(in.Endpoints) if err != nil { r.logger.Errorf("Failed to parse discovery endpoint: %v error %v", in.Endpoints, err) continue @@ -68,19 +82,18 @@ func (r *resolver) fetch(ctx context.Context) []*registry.ServiceInstance { r.lock.RLock() nodes := r.nodes r.lock.RUnlock() - return nodes } -func parseEndpoint(schema string, endpoints []string) (string, error) { +func parseEndpoint(endpoints []string) (string, string, error) { for _, e := range endpoints { u, err := url.Parse(e) if err != nil { - return "", err + return "", "", err } - if u.Scheme == schema { - return u.Host, nil + if u.Scheme == "http" { + return u.Scheme, u.Host, nil } } - return "", nil + return "", "", nil }