diff --git a/config/config.go b/config/config.go index ccefd7f46..c6f5677d4 100644 --- a/config/config.go +++ b/config/config.go @@ -2,12 +2,10 @@ package config import ( "errors" - "fmt" "reflect" "sync" "time" - "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/log" // init encoding @@ -50,17 +48,9 @@ type config struct { // New new a config with options. func New(opts ...Option) Config { options := options{ - logger: log.DefaultLogger, - decoder: func(src *KeyValue, target map[string]interface{}) error { - if src.Format == "" { - target[src.Key] = src.Value - return nil - } - if codec := encoding.GetCodec(src.Format); codec != nil { - return codec.Unmarshal(src.Value, &target) - } - return fmt.Errorf("unsupported key: %s format: %s", src.Key, src.Format) - }, + logger: log.DefaultLogger, + decoder: defaultDecoder, + resolver: defaultResolver, } for _, o := range opts { o(&options) diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 000000000..f9bb5a7b7 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,141 @@ +package config + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultResolver(t *testing.T) { + var ( + portString = "8080" + countInt = 10 + enableBool = true + rateFloat = 0.9 + ) + + data := map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": map[string]interface{}{ + "notexist": "${NOTEXIST:100}", + "port": "${PORT:8081}", + "count": "${COUNT:0}", + "enable": "${ENABLE:false}", + "rate": "${RATE}", + "empty": "${EMPTY:foobar}", + "array": []interface{}{"${PORT}", "${NOTEXIST:8081}"}, + "value1": "${test.value}", + "value2": "$PORT", + "value3": "$PORT:default", + }, + }, + "test": map[string]interface{}{ + "value": "foobar", + }, + "PORT": "8080", + "COUNT": "10", + "ENABLE": "true", + "RATE": "0.9", + "EMPTY": "", + } + + tests := []struct { + name string + path string + expect interface{} + }{ + { + name: "test not exist int env with default", + path: "foo.bar.notexist", + expect: 100, + }, + { + name: "test string with default", + path: "foo.bar.port", + expect: portString, + }, + { + name: "test int with default", + path: "foo.bar.count", + expect: countInt, + }, + { + name: "test bool with default", + path: "foo.bar.enable", + expect: enableBool, + }, + { + name: "test float without default", + path: "foo.bar.rate", + expect: rateFloat, + }, + { + name: "test empty value with default", + path: "foo.bar.empty", + expect: "", + }, + { + name: "test array", + path: "foo.bar.array", + expect: []interface{}{portString, "8081"}, + }, + { + name: "test ${test.value}", + path: "foo.bar.value1", + expect: "foobar", + }, + { + name: "test $value", + path: "foo.bar.value2", + expect: portString, + }, + { + name: "test $value:default", + path: "foo.bar.value3", + expect: portString + ":default", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := defaultResolver(data) + assert.NoError(t, err) + rd := reader{ + values: data, + } + if v, ok := rd.Value(test.path); ok { + var actual interface{} + switch test.expect.(type) { + case int: + if actual, err = v.Int(); err == nil { + assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") + } + case string: + if actual, err = v.String(); err == nil { + assert.Equal(t, test.expect, actual, "string value should be equal") + } + case bool: + if actual, err = v.Bool(); err == nil { + assert.Equal(t, test.expect, actual, "bool value should be equal") + } + case float64: + if actual, err = v.Float(); err == nil { + assert.Equal(t, test.expect, actual, "float64 value should be equal") + } + default: + actual = v.Load() + if !reflect.DeepEqual(test.expect, actual) { + t.Logf("expect: %#v, actural: %#v", test.expect, actual) + t.Fail() + } + } + if err != nil { + t.Error(err) + } + } else { + t.Error("value path not found") + } + }) + } +} diff --git a/config/options.go b/config/options.go index 5bc9ebf0b..278698b34 100644 --- a/config/options.go +++ b/config/options.go @@ -1,19 +1,28 @@ package config import ( + "fmt" + "os" + "strings" + + "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/log" ) // Decoder is config decoder. type Decoder func(*KeyValue, map[string]interface{}) error +// Resolver resolve placeholder in config. +type Resolver func(map[string]interface{}) error + // Option is config option. type Option func(*options) type options struct { - sources []Source - decoder Decoder - logger log.Logger + sources []Source + decoder Decoder + resolver Resolver + logger log.Logger } // WithSource with config source. @@ -30,9 +39,67 @@ func WithDecoder(d Decoder) Option { } } +// WithResolver with config resolver. +func WithResolver(r Resolver) Option { + return func(o *options) { + o.resolver = r + } +} + // WithLogger with config logger. func WithLogger(l log.Logger) Option { return func(o *options) { o.logger = l } } + +// defaultDecoder decode config from source KeyValue +// to target map[string]interface{} using src.Format codec. +func defaultDecoder(src *KeyValue, target map[string]interface{}) error { + if src.Format == "" { + target[src.Key] = src.Value + return nil + } + if codec := encoding.GetCodec(src.Format); codec != nil { + return codec.Unmarshal(src.Value, &target) + } + return fmt.Errorf("unsupported key: %s format: %s", src.Key, src.Format) +} + +// defaultResolver resolve placeholder in map value, +// placeholder format in ${key:default} or $key. +func defaultResolver(input map[string]interface{}) error { + mapper := func(name string) string { + args := strings.Split(strings.TrimSpace(name), ":") + if v, has := readValue(input, args[0]); has { + s, _ := v.String() + return s + } else if len(args) > 1 { // default value + return args[1] + } + return "" + } + + var resolve func(map[string]interface{}) error + resolve = func(sub map[string]interface{}) error { + for k, v := range sub { + switch vt := v.(type) { + case string: + sub[k] = os.Expand(vt, mapper) + case map[string]interface{}: + if err := resolve(vt); err != nil { + return err + } + case []interface{}: + for i, iface := range vt { + if s, ok := iface.(string); ok { + vt[i] = os.Expand(s, mapper) + } + } + sub[k] = vt + } + } + return nil + } + return resolve(input) +} diff --git a/config/reader.go b/config/reader.go index 71806a3ca..2dc778c3e 100644 --- a/config/reader.go +++ b/config/reader.go @@ -43,6 +43,9 @@ func (r *reader) Merge(kvs ...*KeyValue) error { return err } } + if err := r.opts.resolver(merged); err != nil { + return err + } r.values = merged return nil } diff --git a/config/reader_test.go b/config/reader_test.go index d7c3d179b..2cec467af 100644 --- a/config/reader_test.go +++ b/config/reader_test.go @@ -20,6 +20,7 @@ func TestReader_Merge(t *testing.T) { } return fmt.Errorf("unsupported key: %s format: %s", kv.Key, kv.Format) }, + resolver: defaultResolver, } r := newReader(opts) err = r.Merge(&KeyValue{ @@ -55,10 +56,6 @@ func TestReader_Merge(t *testing.T) { } func TestReader_Value(t *testing.T) { - var ( - err error - ok bool - ) opts := options{ decoder: func(kv *KeyValue, v map[string]interface{}) error { if codec := encoding.GetCodec(kv.Format); codec != nil { @@ -66,42 +63,73 @@ func TestReader_Value(t *testing.T) { } return fmt.Errorf("unsupported key: %s format: %s", kv.Key, kv.Format) }, + resolver: defaultResolver, } - r := newReader(opts) - err = r.Merge(&KeyValue{ - Key: "b", - Value: []byte(`{"a": {"b": {"X": 1, "Y": "lol", "z": true}}}`), - Format: "json", - }) - assert.NoError(t, err) - vv, ok := r.Value("a.b.X") - assert.True(t, ok) - vvv, err := vv.Int() - assert.NoError(t, err) - assert.Equal(t, int64(1), vvv) - assert.NoError(t, err) - vv, ok = r.Value("a.b.Y") - assert.True(t, ok) - vvy, err := vv.String() - assert.NoError(t, err) - assert.Equal(t, "lol", vvy) + ymlval := ` +a: + b: + X: 1 + Y: "lol" + z: true +` + tests := []struct { + name string + kv KeyValue + }{ + { + name: "json value", + kv: KeyValue{ + Key: "config", + Value: []byte(`{"a": {"b": {"X": 1, "Y": "lol", "z": true}}}`), + Format: "json", + }, + }, + { + name: "yaml value", + kv: KeyValue{ + Key: "config", + Value: []byte(ymlval), + Format: "yaml", + }, + }, + } - assert.NoError(t, err) - vv, ok = r.Value("a.b.z") - assert.True(t, ok) - vvz, err := vv.Bool() - assert.NoError(t, err) - assert.Equal(t, true, vvz) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReader(opts) + err := r.Merge(&test.kv) + assert.NoError(t, err) + vv, ok := r.Value("a.b.X") + assert.True(t, ok) + vvv, err := vv.Int() + assert.NoError(t, err) + assert.Equal(t, int64(1), vvv) - vv, ok = r.Value("aasasdg=234l.asdfk,") - assert.False(t, ok) + assert.NoError(t, err) + vv, ok = r.Value("a.b.Y") + assert.True(t, ok) + vvy, err := vv.String() + assert.NoError(t, err) + assert.Equal(t, "lol", vvy) - vv, ok = r.Value("aas......asdg=234l.asdfk,") - assert.False(t, ok) + assert.NoError(t, err) + vv, ok = r.Value("a.b.z") + assert.True(t, ok) + vvz, err := vv.Bool() + assert.NoError(t, err) + assert.Equal(t, true, vvz) - vv, ok = r.Value("a.b.Y.") - assert.False(t, ok) + vv, ok = r.Value("aasasdg=234l.asdfk,") + assert.False(t, ok) + + vv, ok = r.Value("aas......asdg=234l.asdfk,") + assert.False(t, ok) + + vv, ok = r.Value("a.b.Y.") + assert.False(t, ok) + }) + } } func TestReader_Source(t *testing.T) { @@ -115,6 +143,7 @@ func TestReader_Source(t *testing.T) { } return fmt.Errorf("unsupported key: %s format: %s", kv.Key, kv.Format) }, + resolver: defaultResolver, } r := newReader(opts) err = r.Merge(&KeyValue{