|
|
|
package grpc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
|
|
|
|
ic "github.com/go-kratos/kratos/v2/internal/context"
|
|
|
|
"github.com/go-kratos/kratos/v2/middleware"
|
|
|
|
"github.com/go-kratos/kratos/v2/transport"
|
|
|
|
"google.golang.org/grpc"
|
|
|
|
grpcmd "google.golang.org/grpc/metadata"
|
|
|
|
)
|
|
|
|
|
|
|
|
// unaryServerInterceptor is a gRPC unary server interceptor
|
|
|
|
func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor {
|
|
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
|
|
|
ctx, cancel := ic.Merge(ctx, s.baseCtx)
|
|
|
|
defer cancel()
|
|
|
|
md, _ := grpcmd.FromIncomingContext(ctx)
|
|
|
|
replyHeader := grpcmd.MD{}
|
|
|
|
tr := &Transport{
|
|
|
|
operation: info.FullMethod,
|
|
|
|
reqHeader: headerCarrier(md),
|
|
|
|
replyHeader: headerCarrier(replyHeader),
|
|
|
|
}
|
|
|
|
if s.endpoint != nil {
|
|
|
|
tr.endpoint = s.endpoint.String()
|
|
|
|
}
|
|
|
|
ctx = transport.NewServerContext(ctx, tr)
|
|
|
|
if s.timeout > 0 {
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, s.timeout)
|
|
|
|
defer cancel()
|
|
|
|
}
|
|
|
|
h := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
|
|
return handler(ctx, req)
|
|
|
|
}
|
|
|
|
if len(s.middleware) > 0 {
|
|
|
|
h = middleware.Chain(s.middleware...)(h)
|
|
|
|
}
|
|
|
|
reply, err := h(ctx, req)
|
|
|
|
if len(replyHeader) > 0 {
|
|
|
|
_ = grpc.SetHeader(ctx, replyHeader)
|
|
|
|
}
|
|
|
|
return reply, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// wrappedStream wraps a grpc.ServerStream so that its context can be replaced.
// The embedded ServerStream provides the stream's methods, and Context is
// overridden to return ctx instead of the embedded stream's context.
type wrappedStream struct {
	grpc.ServerStream
	// ctx is the context reported by Context in place of the embedded stream's.
	ctx context.Context
}
|
|
|
|
|
|
|
|
func NewWrappedStream(ctx context.Context, stream grpc.ServerStream) grpc.ServerStream {
|
|
|
|
return &wrappedStream{
|
|
|
|
ServerStream: stream,
|
|
|
|
ctx: ctx,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (w *wrappedStream) Context() context.Context {
|
|
|
|
return w.ctx
|
|
|
|
}
|
|
|
|
|
|
|
|
// streamServerInterceptor is a gRPC stream server interceptor
|
|
|
|
func (s *Server) streamServerInterceptor() grpc.StreamServerInterceptor {
|
|
|
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
|
|
ctx, cancel := ic.Merge(ss.Context(), s.baseCtx)
|
|
|
|
defer cancel()
|
|
|
|
md, _ := grpcmd.FromIncomingContext(ctx)
|
|
|
|
replyHeader := grpcmd.MD{}
|
|
|
|
ctx = transport.NewServerContext(ctx, &Transport{
|
|
|
|
endpoint: s.endpoint.String(),
|
|
|
|
operation: info.FullMethod,
|
|
|
|
reqHeader: headerCarrier(md),
|
|
|
|
replyHeader: headerCarrier(replyHeader),
|
|
|
|
})
|
|
|
|
|
|
|
|
ws := NewWrappedStream(ctx, ss)
|
|
|
|
|
|
|
|
err := handler(srv, ws)
|
|
|
|
if len(replyHeader) > 0 {
|
|
|
|
_ = grpc.SetHeader(ctx, replyHeader)
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|