add response header (#1119)

* add response header

Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
pull/1125/head
longxboy 3 years ago committed by GitHub
parent 493c11929f
commit 545ffd1084
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 55
      examples/header/client/main.go
  2. 60
      examples/header/server/main.go
  3. 4
      examples/metadata/client/main.go
  4. 3
      examples/metadata/server/main.go
  5. 24
      internal/testproto/echo_service_test.go
  6. 8
      middleware/logging/logging_test.go
  7. 12
      middleware/metadata/metadata.go
  8. 17
      middleware/metadata/metadata_test.go
  9. 4
      middleware/tracing/tracing.go
  10. 15
      middleware/tracing/tracing_test.go
  11. 7
      transport/grpc/client.go
  12. 14
      transport/grpc/server.go
  13. 18
      transport/grpc/transport.go
  14. 24
      transport/http/calloption.go
  15. 12
      transport/http/client.go
  16. 3
      transport/http/server.go
  17. 20
      transport/http/transport.go
  18. 12
      transport/transport.go

@ -0,0 +1,55 @@
package main
import (
"context"
"log"
stdhttp "net/http"
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http"
stdgrpc "google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
)
func main() {
callHTTP()
callGRPC()
}
func callHTTP() {
conn, err := http.NewClient(
context.Background(),
http.WithEndpoint("127.0.0.1:8000"),
)
if err != nil {
panic(err)
}
client := helloworld.NewGreeterHTTPClient(conn)
ctx := context.Background()
var header stdhttp.Header
reply, err := client.SayHello(ctx, &helloworld.HelloRequest{Name: "kratos"}, http.Header(&header))
if err != nil {
log.Fatal(err)
}
log.Printf("[http] SayHello %s header: %v\n", reply.Message, header)
}
func callGRPC() {
conn, err := grpc.DialInsecure(
context.Background(),
grpc.WithEndpoint("127.0.0.1:9000"),
)
if err != nil {
log.Fatal(err)
}
client := helloworld.NewGreeterClient(conn)
ctx := context.Background()
var md grpcmd.MD
reply, err := client.SayHello(ctx, &helloworld.HelloRequest{Name: "kratos"}, stdgrpc.Header(&md))
if err != nil {
log.Fatal(err)
}
log.Printf("[grpc] SayHello %+v header: %v\n", reply, md)
}

@ -0,0 +1,60 @@
package main
import (
"context"
"fmt"
"log"
"github.com/go-kratos/kratos/examples/helloworld/helloworld"
"github.com/go-kratos/kratos/v2"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http"
)
// go build -ldflags "-X main.Version=x.y.z"
var (
// Name is the name of the compiled software.
Name = "helloworld"
// Version is the version of the compiled software.
Version = "v1.0.0"
)
// server is used to implement helloworld.GreeterServer.
type server struct {
helloworld.UnimplementedGreeterServer
}
// SayHello implements helloworld.GreeterServer
func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
info, _ := kratos.FromContext(ctx)
if tr, ok := transport.FromServerContext(ctx); ok {
tr.ReplyHeader().Set("app_name", info.Name())
}
return &helloworld.HelloReply{Message: fmt.Sprintf("Hello %s", in.Name)}, nil
}
func main() {
grpcSrv := grpc.NewServer(
grpc.Address(":9000"),
)
httpSrv := http.NewServer(
http.Address(":8000"),
)
s := &server{}
helloworld.RegisterGreeterServer(grpcSrv, s)
helloworld.RegisterGreeterHTTPServer(httpSrv, s)
app := kratos.New(
kratos.Name(Name),
kratos.Server(
httpSrv,
grpcSrv,
),
)
if err := app.Run(); err != nil {
log.Fatal(err)
}
}

@ -34,7 +34,7 @@ func callHTTP() {
if err != nil {
log.Fatal(err)
}
log.Printf("[http] SayHello %s\n", reply.Message)
log.Printf("[http] SayHello %s\n", reply)
}
func callGRPC() {
@ -55,5 +55,5 @@ func callGRPC() {
if err != nil {
log.Fatal(err)
}
log.Printf("[grpc] SayHello %+v\n", reply)
log.Printf("[grpc] SayHello %+v \n", reply)
}

@ -32,8 +32,7 @@ func (s *server) SayHello(ctx context.Context, in *helloworld.HelloRequest) (*he
if md, ok := metadata.FromServerContext(ctx); ok {
extra = md.Get("x-md-global-extra")
}
info, _ := kratos.FromContext(ctx)
return &helloworld.HelloReply{Message: fmt.Sprintf("Hello %s extra: %s name: %s", in.Name, extra, info.Name())}, nil
return &helloworld.HelloReply{Message: fmt.Sprintf("Hello %s extra_meta: %s", in.Name, extra)}, nil
}
func main() {

@ -4,16 +4,20 @@ import (
context "context"
"errors"
"fmt"
stdhttp "net/http"
"testing"
"time"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/metadata"
mmd "github.com/go-kratos/kratos/v2/middleware/metadata"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/grpc"
"github.com/go-kratos/kratos/v2/transport/http"
_struct "github.com/golang/protobuf/ptypes/struct"
stdgrpc "google.golang.org/grpc"
grpcmd "google.golang.org/grpc/metadata"
)
var md = metadata.Metadata{"x-md-global-test": "test_value"}
@ -27,6 +31,10 @@ func (s *echoService) Echo(ctx context.Context, m *SimpleMessage) (*SimpleMessag
if v := md.Get("x-md-global-test"); v != "test_value" {
return nil, errors.New("md not match" + v)
}
if tr, ok := transport.FromServerContext(ctx); ok {
tr.ReplyHeader().Set("2233", "niang")
}
return m, nil
}
@ -51,8 +59,8 @@ type echoClient struct {
}
// post: /v1/example/echo/{id}
func (c *echoClient) Echo(ctx context.Context, in *SimpleMessage) (out *SimpleMessage, err error) {
return c.client.Echo(ctx, in)
func (c *echoClient) Echo(ctx context.Context, in *SimpleMessage, opts ...http.CallOption) (out *SimpleMessage, err error) {
return c.client.Echo(ctx, in, opts...)
}
// post: /v1/example/echo_body
@ -138,9 +146,13 @@ func testEchoHTTPClient(t *testing.T, addr string) {
ctx := context.Background()
ctx = metadata.NewClientContext(ctx, md)
if out, err = cli.Echo(ctx, in); err != nil {
var header stdhttp.Header
if out, err = cli.Echo(ctx, in, http.Header(&header)); err != nil {
t.Fatal(err)
}
if header.Get("2233") != "niang" {
t.Errorf("[echo] header key 2233 expected niang got %v", header.Get("2233"))
}
check("echo", &SimpleMessage{Id: "test_id"}, out)
if out, err = cli.EchoBody(context.Background(), in); err != nil {
@ -211,9 +223,13 @@ func testEchoGRPCClient(t *testing.T, addr string) {
client := NewEchoServiceClient(cc)
ctx := context.Background()
ctx = metadata.NewClientContext(ctx, md)
if out, err = client.Echo(ctx, in); err != nil {
var md grpcmd.MD
if out, err = client.Echo(ctx, in, stdgrpc.Header(&md)); err != nil {
t.Fatal(err)
}
if len(md.Get("2233")) != 1 || md.Get("2233")[0] != "niang" {
t.Errorf("[echo] header key 2233 expected niang got %v", md.Get("2233"))
}
if in.Id != out.Id || in.Num != out.Num {
t.Errorf("expected %v got %v", in, out)
}

@ -24,16 +24,16 @@ type Transport struct {
func (tr *Transport) Kind() transport.Kind {
return tr.kind
}
func (tr *Transport) Endpoint() string {
return tr.endpoint
}
func (tr *Transport) Operation() string {
return tr.operation
}
func (tr *Transport) Header() transport.Header {
func (tr *Transport) RequestHeader() transport.Header {
return nil
}
func (tr *Transport) ReplyHeader() transport.Header {
return nil
}

@ -53,9 +53,10 @@ func Server(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
md := options.md.Clone()
for _, k := range tr.Header().Keys() {
header := tr.RequestHeader()
for _, k := range header.Keys() {
if options.hasPrefix(k) {
md.Set(k, tr.Header().Get(k))
md.Set(k, header.Get(k))
}
}
ctx = metadata.NewServerContext(ctx, md)
@ -76,20 +77,21 @@ func Client(opts ...Option) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromClientContext(ctx); ok {
header := tr.RequestHeader()
// x-md-local-
for k, v := range options.md {
tr.Header().Set(k, v)
header.Set(k, v)
}
if md, ok := metadata.FromClientContext(ctx); ok {
for k, v := range md {
tr.Header().Set(k, v)
header.Set(k, v)
}
}
// x-md-global-
if md, ok := metadata.FromServerContext(ctx); ok {
for k, v := range md {
if options.hasPrefix(k) {
tr.Header().Set(k, v)
header.Set(k, v)
}
}
}

@ -27,10 +27,11 @@ func (hc headerCarrier) Keys() []string {
type testTransport struct{ header headerCarrier }
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP }
func (tr *testTransport) Endpoint() string { return "" }
func (tr *testTransport) Operation() string { return "" }
func (tr *testTransport) Header() transport.Header { return tr.header }
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP }
func (tr *testTransport) Endpoint() string { return "" }
func (tr *testTransport) Operation() string { return "" }
func (tr *testTransport) RequestHeader() transport.Header { return tr.header }
func (tr *testTransport) ReplyHeader() transport.Header { return tr.header }
func TestSever(t *testing.T) {
var (
@ -89,16 +90,16 @@ func TestClient(t *testing.T) {
if !ok {
return nil, errors.New("no md")
}
if tr.Header().Get(constKey) != constValue {
if tr.RequestHeader().Get(constKey) != constValue {
return nil, errors.New("const not equal")
}
if tr.Header().Get(customKey) != customValue {
if tr.RequestHeader().Get(customKey) != customValue {
return nil, errors.New("custom not equal")
}
if tr.Header().Get(globalKey) != globalValue {
if tr.RequestHeader().Get(globalKey) != globalValue {
return nil, errors.New("global not equal")
}
if tr.Header().Get(localKey) != "" {
if tr.RequestHeader().Get(localKey) != "" {
return nil, errors.New("local must empty")
}
return in, nil

@ -38,7 +38,7 @@ func Server(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
var span trace.Span
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.Header())
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.RequestHeader())
defer func() { tracer.End(ctx, span, err) }()
}
return handler(ctx, req)
@ -53,7 +53,7 @@ func Client(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromClientContext(ctx); ok {
var span trace.Span
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.Header())
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.RequestHeader())
defer func() { tracer.End(ctx, span, err) }()
}
return handler(ctx, req)

@ -43,10 +43,11 @@ type Transport struct {
header headerCarrier
}
func (tr *Transport) Kind() transport.Kind { return tr.kind }
func (tr *Transport) Endpoint() string { return tr.endpoint }
func (tr *Transport) Operation() string { return tr.operation }
func (tr *Transport) Header() transport.Header { return tr.header }
func (tr *Transport) Kind() transport.Kind { return tr.kind }
func (tr *Transport) Endpoint() string { return tr.endpoint }
func (tr *Transport) Operation() string { return tr.operation }
func (tr *Transport) RequestHeader() transport.Header { return tr.header }
func (tr *Transport) ReplyHeader() transport.Header { return tr.header }
func TestTracing(t *testing.T) {
var carrier = headerCarrier{}
@ -56,21 +57,21 @@ func TestTracing(t *testing.T) {
tracer := NewTracer(trace.SpanKindClient, WithTracerProvider(tp), WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})))
ts := &Transport{kind: transport.KindHTTP, header: carrier}
ctx, aboveSpan := tracer.Start(transport.NewClientContext(context.Background(), ts), ts.Kind().String(), ts.Operation(), ts.Header())
ctx, aboveSpan := tracer.Start(transport.NewClientContext(context.Background(), ts), ts.Kind().String(), ts.Operation(), ts.RequestHeader())
defer tracer.End(ctx, aboveSpan, nil)
// server use Extract fetch traceInfo from carrier
tracer = NewTracer(trace.SpanKindServer, WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})))
ts = &Transport{kind: transport.KindHTTP, header: carrier}
ctx, span := tracer.Start(transport.NewServerContext(ctx, ts), ts.Kind().String(), ts.Operation(), ts.Header())
ctx, span := tracer.Start(transport.NewServerContext(ctx, ts), ts.Kind().String(), ts.Operation(), ts.RequestHeader())
defer tracer.End(ctx, span, nil)
if aboveSpan.SpanContext().TraceID() != span.SpanContext().TraceID() {
t.Fatalf("TraceID failed to deliver")
}
if v, ok := transport.FromClientContext(ctx); !ok || len(v.Header().Keys()) == 0 {
if v, ok := transport.FromClientContext(ctx); !ok || len(v.RequestHeader().Keys()) == 0 {
t.Fatalf("traceHeader failed to deliver")
}
}

@ -116,7 +116,7 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration) g
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: cc.Target(),
operation: method,
header: headerCarrier{},
reqHeader: headerCarrier{},
})
if timeout > 0 {
var cancel context.CancelFunc
@ -125,10 +125,11 @@ func unaryClientInterceptor(ms []middleware.Middleware, timeout time.Duration) g
}
h := func(ctx context.Context, req interface{}) (interface{}, error) {
if tr, ok := transport.FromClientContext(ctx); ok {
keys := tr.Header().Keys()
header := tr.RequestHeader()
keys := header.Keys()
keyvals := make([]string, 0, len(keys))
for _, k := range keys {
keyvals = append(keyvals, k, tr.Header().Get(k))
keyvals = append(keyvals, k, header.Get(k))
}
ctx = grpcmd.AppendToOutgoingContext(ctx, keyvals...)
}

@ -177,10 +177,12 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
ctx, cancel := ic.Merge(ctx, s.ctx)
defer cancel()
md, _ := grpcmd.FromIncomingContext(ctx)
replyHeader := grpcmd.MD{}
ctx = transport.NewServerContext(ctx, &Transport{
endpoint: s.endpoint.String(),
operation: info.FullMethod,
header: headerCarrier(md),
endpoint: s.endpoint.String(),
operation: info.FullMethod,
reqHeader: headerCarrier(md),
replyHeader: headerCarrier(replyHeader),
})
if s.timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, s.timeout)
@ -192,6 +194,10 @@ func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
if len(s.middleware) > 0 {
h = middleware.Chain(s.middleware...)(h)
}
return h(ctx, req)
reply, err := h(ctx, req)
if len(replyHeader) > 0 {
grpc.SetHeader(ctx, replyHeader)
}
return reply, err
}
}

@ -11,9 +11,10 @@ var (
// Transport is a gRPC transport.
type Transport struct {
endpoint string
operation string
header headerCarrier
endpoint string
operation string
reqHeader headerCarrier
replyHeader headerCarrier
}
// Kind returns the transport kind.
@ -31,9 +32,14 @@ func (tr *Transport) Operation() string {
return tr.operation
}
// Header returns the transport header.
func (tr *Transport) Header() transport.Header {
return tr.header
// RequestHeader returns the request header.
func (tr *Transport) RequestHeader() transport.Header {
return tr.reqHeader
}
// ReplyHeader returns the reply header.
func (tr *Transport) ReplyHeader() transport.Header {
return tr.replyHeader
}
type headerCarrier metadata.MD

@ -1,5 +1,7 @@
package http
import "net/http"
// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
type CallOption interface {
@ -26,7 +28,9 @@ type EmptyCallOption struct{}
func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo, *csAttempt) {}
type csAttempt struct{}
type csAttempt struct {
res *http.Response
}
// ContentType with request content type.
func ContentType(contentType string) CallOption {
@ -83,3 +87,21 @@ func (o PathTemplateCallOption) before(c *callInfo) error {
c.pathTemplate = o.Pattern
return nil
}
// Header returns a CallOptions that retrieves the http response header
// from server reply.
func Header(header *http.Header) CallOption {
return HeaderCallOption{header: header}
}
// HeaderCallOption is retrive response header for client call
type HeaderCallOption struct {
EmptyCallOption
header *http.Header
}
func (o HeaderCallOption) after(c *callInfo, cs *csAttempt) {
if cs.res != nil && cs.res.Header != nil {
*o.header = cs.res.Header
}
}

@ -199,15 +199,15 @@ func (client *Client) Invoke(ctx context.Context, method, path string, args inte
}
ctx = transport.NewClientContext(ctx, &Transport{
endpoint: client.opts.endpoint,
header: headerCarrier(req.Header),
reqHeader: headerCarrier(req.Header),
operation: c.operation,
request: req,
pathTemplate: c.pathTemplate,
})
return client.invoke(ctx, req, args, reply, c)
return client.invoke(ctx, req, args, reply, c, opts...)
}
func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo) error {
func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo, opts ...CallOption) error {
h := func(ctx context.Context, in interface{}) (interface{}, error) {
var done func(context.Context, balancer.DoneInfo)
if client.r != nil {
@ -230,6 +230,12 @@ func (client *Client) invoke(ctx context.Context, req *http.Request, args interf
if done != nil {
done(ctx, balancer.DoneInfo{Err: err})
}
if res != nil {
cs := csAttempt{res: res}
for _, o := range opts {
o.after(&c, &cs)
}
}
if err != nil {
return nil, err
}

@ -169,7 +169,8 @@ func (s *Server) filter() mux.MiddlewareFunc {
tr := &Transport{
endpoint: s.endpoint.String(),
operation: pathTemplate,
header: headerCarrier(req.Header),
reqHeader: headerCarrier(req.Header),
replyHeader: headerCarrier(w.Header()),
request: req,
pathTemplate: pathTemplate,
}

@ -15,7 +15,8 @@ var (
type Transport struct {
endpoint string
operation string
header headerCarrier
reqHeader headerCarrier
replyHeader headerCarrier
request *http.Request
pathTemplate string
}
@ -35,16 +36,21 @@ func (tr *Transport) Operation() string {
return tr.operation
}
// Header returns the transport header.
func (tr *Transport) Header() transport.Header {
return tr.header
}
// Request returns the transport request.
// Request returns the HTTP request.
func (tr *Transport) Request() *http.Request {
return tr.request
}
// RequestHeader returns the request header.
func (tr *Transport) RequestHeader() transport.Header {
return tr.reqHeader
}
// ReplyHeader returns the reply header.
func (tr *Transport) ReplyHeader() transport.Header {
return tr.replyHeader
}
// PathTemplate returns the http path template.
func (tr *Transport) PathTemplate() string {
return tr.pathTemplate

@ -40,17 +40,21 @@ type Transporter interface {
// Service full method selector generated by protobuf
// example: /helloworld.Greeter/SayHello
Operation() string
// request header
// http: http.Header
// grpc: metadata.MD
Header() Header
RequestHeader() Header
// reply header
// only valid for server transport
// http: http.Header
// grpc: metadata.MD
ReplyHeader() Header
}
// Kind defines the type of Transport
type Kind string
func (k Kind) String() string {
return string(k)
}
func (k Kind) String() string { return string(k) }
// Defines a set of transport kind
const (

Loading…
Cancel
Save