merge master

fix/windows_path
haiyux 2 years ago
commit 1b862ccccc
  1. 2
      .github/workflows/go.yml
  2. 22
      encoding/proto/proto.go
  3. 58
      encoding/proto/proto_test.go

@ -36,7 +36,7 @@ jobs:
- "8848:8848" - "8848:8848"
- "9848:9848" - "9848:9848"
polaris: polaris:
image: polarismesh/polaris-standalone:v1.17.0 image: polarismesh/polaris-standalone:latest
ports: ports:
- 8090:8090 - 8090:8090
- 8091:8091 - 8091:8091

@ -3,6 +3,9 @@
package proto package proto
import ( import (
"errors"
"reflect"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding"
@ -23,9 +26,26 @@ func (codec) Marshal(v interface{}) ([]byte, error) {
} }
func (codec) Unmarshal(data []byte, v interface{}) error { func (codec) Unmarshal(data []byte, v interface{}) error {
return proto.Unmarshal(data, v.(proto.Message)) pm, err := getProtoMessage(v)
if err != nil {
return err
}
return proto.Unmarshal(data, pm)
} }
func (codec) Name() string { func (codec) Name() string {
return Name return Name
} }
func getProtoMessage(v interface{}) (proto.Message, error) {
if msg, ok := v.(proto.Message); ok {
return msg, nil
}
val := reflect.ValueOf(v)
if val.Kind() != reflect.Ptr {
return nil, errors.New("not proto message")
}
val = val.Elem()
return getProtoMessage(val.Interface())
}

@ -44,3 +44,61 @@ func TestCodec(t *testing.T) {
t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby) t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby)
} }
} }
func TestCodec2(t *testing.T) {
c := new(codec)
model := testData.TestModel{
Id: 1,
Name: "kratos",
Hobby: []string{"study", "eat", "play"},
}
m, err := c.Marshal(&model)
if err != nil {
t.Errorf("Marshal() should be nil, but got %s", err)
}
var res testData.TestModel
rp := &res
err = c.Unmarshal(m, &rp)
if err != nil {
t.Errorf("Unmarshal() should be nil, but got %s", err)
}
if !reflect.DeepEqual(res.Id, model.Id) {
t.Errorf("ID should be %d, but got %d", res.Id, model.Id)
}
if !reflect.DeepEqual(res.Name, model.Name) {
t.Errorf("Name should be %s, but got %s", res.Name, model.Name)
}
if !reflect.DeepEqual(res.Hobby, model.Hobby) {
t.Errorf("Hobby should be %s, but got %s", res.Hobby, model.Hobby)
}
}
func Test_getProtoMessage(t *testing.T) {
p := &testData.TestModel{Id: 1}
type args struct {
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "test1", args: args{v: &testData.TestModel{}}, wantErr: false},
{name: "test2", args: args{v: testData.TestModel{}}, wantErr: true},
{name: "test3", args: args{v: &p}, wantErr: false},
{name: "test4", args: args{v: 1}, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := getProtoMessage(tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("getProtoMessage() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

Loading…
Cancel
Save