package warden import ( "bytes" "context" "errors" "io/ioutil" "os" "reflect" "testing" "time" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "github.com/bilibili/kratos/pkg/log" ) func Test_logFn(t *testing.T) { type args struct { code int dt time.Duration } tests := []struct { name string args args want func(context.Context, ...log.D) }{ { name: "ok", args: args{code: 0, dt: time.Millisecond}, want: log.Infov, }, { name: "slowlog", args: args{code: 0, dt: time.Second}, want: log.Warnv, }, { name: "business error", args: args{code: 2233, dt: time.Millisecond}, want: log.Warnv, }, { name: "system error", args: args{code: -1, dt: 0}, want: log.Errorv, }, { name: "system error and slowlog", args: args{code: -1, dt: time.Second}, want: log.Errorv, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := logFn(tt.args.code, tt.args.dt); reflect.ValueOf(got).Pointer() != reflect.ValueOf(tt.want).Pointer() { t.Errorf("unexpect log function!") } }) } } func callInterceptor(err error, interceptor grpc.UnaryClientInterceptor, opts ...grpc.CallOption) { interceptor(context.Background(), "test-method", bytes.NewBufferString("test-req"), "test_reply", &grpc.ClientConn{}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return err }, opts...) } func TestClientLog(t *testing.T) { stderr, err := ioutil.TempFile(os.TempDir(), "stderr") if err != nil { t.Fatal(err) } old := os.Stderr os.Stderr = stderr t.Logf("capture stderr file: %s", stderr.Name()) t.Run("test no option", func(t *testing.T) { callInterceptor(nil, clientLogging()) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-req") assert.Contains(t, string(data), "path") assert.Contains(t, string(data), "ret") assert.Contains(t, string(data), "ts") assert.Contains(t, string(data), "grpc-access-log") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test disable args", func(t *testing.T) { callInterceptor(nil, clientLogging(WithDialLogFlag(LogFlagDisableArgs))) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.NotContains(t, string(data), "test-req") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test disable args and disable info", func(t *testing.T) { callInterceptor(nil, clientLogging(WithDialLogFlag(LogFlagDisableArgs|LogFlagDisableInfo))) callInterceptor(errors.New("test-error"), clientLogging(WithDialLogFlag(LogFlagDisableArgs|LogFlagDisableInfo))) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-error") assert.NotContains(t, string(data), "INFO") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test call option", func(t *testing.T) { callInterceptor(nil, clientLogging(), WithLogFlag(LogFlagDisableArgs)) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.NotContains(t, string(data), "test-req") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test combine option", func(t *testing.T) { interceptor := clientLogging(WithDialLogFlag(LogFlagDisableInfo)) callInterceptor(nil, interceptor, WithLogFlag(LogFlagDisableArgs)) callInterceptor(errors.New("test-error"), interceptor, WithLogFlag(LogFlagDisableArgs)) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-error") assert.NotContains(t, string(data), "INFO") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test no log", func(t *testing.T) { callInterceptor(errors.New("test error"), clientLogging(WithDialLogFlag(LogFlagDisable))) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Empty(t, data) stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test multi flag", func(t *testing.T) { interceptor := clientLogging(WithDialLogFlag(LogFlagDisableInfo | LogFlagDisableArgs)) callInterceptor(nil, interceptor) callInterceptor(errors.New("test-error"), interceptor) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-error") assert.NotContains(t, string(data), "INFO") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) os.Stderr = old } func callServerInterceptor(err error, interceptor grpc.UnaryServerInterceptor) { interceptor(context.Background(), bytes.NewBufferString("test-req"), &grpc.UnaryServerInfo{ FullMethod: "test-method", }, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, err }) } func TestServerLog(t *testing.T) { stderr, err := ioutil.TempFile(os.TempDir(), "stderr") if err != nil { t.Fatal(err) } old := os.Stderr os.Stderr = stderr t.Logf("capture stderr file: %s", stderr.Name()) t.Run("test no option", func(t *testing.T) { callServerInterceptor(nil, serverLogging(0)) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-req") assert.Contains(t, string(data), "path") assert.Contains(t, string(data), "ret") assert.Contains(t, string(data), "ts") assert.Contains(t, string(data), "grpc-access-log") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test disable args", func(t *testing.T) { callServerInterceptor(nil, serverLogging(LogFlagDisableArgs)) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.NotContains(t, string(data), "test-req") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test no log", func(t *testing.T) { callServerInterceptor(errors.New("test error"), serverLogging(LogFlagDisable)) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Empty(t, data) stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) t.Run("test multi flag", func(t *testing.T) { interceptor := serverLogging(LogFlagDisableInfo | LogFlagDisableArgs) callServerInterceptor(nil, interceptor) callServerInterceptor(errors.New("test-error"), interceptor) stderr.Seek(0, os.SEEK_SET) data, err := ioutil.ReadAll(stderr) if err != nil { t.Error(err) } assert.Contains(t, string(data), "test-method") assert.Contains(t, string(data), "test-error") assert.NotContains(t, string(data), "INFO") stderr.Seek(0, os.SEEK_SET) stderr.Truncate(0) }) os.Stderr = old }