diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7d41aea50..f1effa9cb 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -13,7 +13,7 @@ jobs: build: strategy: matrix: - go: [1.18,1.19,1.20.x] + go: [1.18, 1.19, 1.20.x] name: build & test runs-on: ubuntu-latest services: @@ -36,7 +36,7 @@ jobs: - "8848:8848" - "9848:9848" polaris: - image: polarismesh/polaris-standalone:v1.17.0 + image: polarismesh/polaris-standalone:latest ports: - 8090:8090 - 8091:8091 diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index 0da5e0dcb..8c6ae58d1 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -3,6 +3,9 @@ package proto import ( + "errors" + "reflect" + "google.golang.org/protobuf/proto" "github.com/go-kratos/kratos/v2/encoding" @@ -23,9 +26,26 @@ func (codec) Marshal(v interface{}) ([]byte, error) { } func (codec) Unmarshal(data []byte, v interface{}) error { - return proto.Unmarshal(data, v.(proto.Message)) + pm, err := getProtoMessage(v) + if err != nil { + return err + } + return proto.Unmarshal(data, pm) } func (codec) Name() string { return Name } + +func getProtoMessage(v interface{}) (proto.Message, error) { + if msg, ok := v.(proto.Message); ok { + return msg, nil + } + val := reflect.ValueOf(v) + if val.Kind() != reflect.Ptr { + return nil, errors.New("not proto message") + } + + val = val.Elem() + return getProtoMessage(val.Interface()) +} diff --git a/encoding/proto/proto_test.go b/encoding/proto/proto_test.go index 1d33b3397..2c2ffece9 100644 --- a/encoding/proto/proto_test.go +++ b/encoding/proto/proto_test.go @@ -44,3 +44,61 @@ func TestCodec(t *testing.T) { t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby) } } + +func TestCodec2(t *testing.T) { + c := new(codec) + + model := testData.TestModel{ + Id: 1, + Name: "kratos", + Hobby: []string{"study", "eat", "play"}, + } + + m, err := c.Marshal(&model) + if err != nil { + t.Errorf("Marshal() should be nil, but got %s", err) + } + + var res testData.TestModel + rp := &res + + err = c.Unmarshal(m, &rp) + if err != nil { + t.Errorf("Unmarshal() should be nil, but got %s", err) + } + if !reflect.DeepEqual(res.Id, model.Id) { + t.Errorf("ID should be %d, but got %d", res.Id, model.Id) + } + if !reflect.DeepEqual(res.Name, model.Name) { + t.Errorf("Name should be %s, but got %s", res.Name, model.Name) + } + if !reflect.DeepEqual(res.Hobby, model.Hobby) { + t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby) + } +} + +func Test_getProtoMessage(t *testing.T) { + p := &testData.TestModel{Id: 1} + type args struct { + v interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "test1", args: args{v: &testData.TestModel{}}, wantErr: false}, + {name: "test2", args: args{v: testData.TestModel{}}, wantErr: true}, + {name: "test3", args: args{v: &p}, wantErr: false}, + {name: "test4", args: args{v: 1}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getProtoMessage(tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("getProtoMessage() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +}