diff --git a/config/config_test.go b/config/config_test.go index f9bb5a7b7..f066aa353 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -24,10 +24,13 @@ func TestDefaultResolver(t *testing.T) { "enable": "${ENABLE:false}", "rate": "${RATE}", "empty": "${EMPTY:foobar}", - "array": []interface{}{"${PORT}", "${NOTEXIST:8081}"}, - "value1": "${test.value}", - "value2": "$PORT", - "value3": "$PORT:default", + "array": []interface{}{ + "${PORT}", + map[string]interface{}{"foobar": "${NOTEXIST:8081}"}, + }, + "value1": "${test.value}", + "value2": "$PORT", + "value3": "$PORT:default", }, }, "test": map[string]interface{}{ @@ -78,7 +81,7 @@ func TestDefaultResolver(t *testing.T) { { name: "test array", path: "foo.bar.array", - expect: []interface{}{portString, "8081"}, + expect: []interface{}{portString, map[string]interface{}{"foobar": "8081"}}, }, { name: "test ${test.value}", diff --git a/config/env/env.go b/config/env/env.go new file mode 100644 index 000000000..adf37c989 --- /dev/null +++ b/config/env/env.go @@ -0,0 +1,62 @@ +package env + +import ( + "os" + "strings" + + "github.com/go-kratos/kratos/v2/config" +) + +type env struct { + prefixs []string +} + +func NewSource(prefixs ...string) config.Source { + return &env{prefixs: prefixs} +} + +func (e *env) Load() (kv []*config.KeyValue, err error) { + for _, envstr := range os.Environ() { + var k, v string + subs := strings.SplitN(envstr, "=", 2) + k = subs[0] + if len(subs) > 1 { + v = subs[1] + } + + if len(e.prefixs) > 0 { + p, ok := matchPrefix(e.prefixs, envstr) + if !ok { + continue + } + // trim prefix + k = k[len(p):] + if k[0] == '_' { + k = k[1:] + } + } + + kv = append(kv, &config.KeyValue{ + Key: k, + Value: []byte(v), + }) + } + return +} + +func (e *env) Watch() (config.Watcher, error) { + w, err := NewWatcher() + if err != nil { + return nil, err + } + return w, nil +} + +func matchPrefix(prefixs []string, s string) (string, bool) { + for _, p := range prefixs { + if strings.HasPrefix(s, p) { + return p, true + } + } + return "", false +} diff --git a/config/env/env_test.go b/config/env/env_test.go new file mode 100644 index 000000000..49d8e1816 --- /dev/null +++ b/config/env/env_test.go @@ -0,0 +1,247 @@ +package env + +import ( + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/go-kratos/kratos/v2/config" + "github.com/go-kratos/kratos/v2/config/file" + "github.com/stretchr/testify/assert" +) + +const _testJSON = ` +{ + "test":{ + "server":{ + "name":"$SERVICE_NAME", + "addr":"${ADDR:127.0.0.1}", + "port":"${PORT:8080}" + } + }, + "foo":[ + { + "name":"Tom", + "age":"${AGE}" + } + ] +}` + +func TestEnvWithPrefix(t *testing.T) { + var ( + path = filepath.Join(os.TempDir(), "test_config") + filename = filepath.Join(path, "test.json") + data = []byte(_testJSON) + ) + defer os.Remove(path) + if err := os.MkdirAll(path, 0700); err != nil { + t.Error(err) + } + if err := ioutil.WriteFile(filename, data, 0666); err != nil { + t.Error(err) + } + + // set env + prefix1, prefix2 := "KRATOS_", "FOO" + envs := map[string]string{ + prefix1 + "SERVICE_NAME": "kratos_app", + prefix2 + "ADDR": "192.168.0.1", + prefix1 + "AGE": "20", + } + + for k, v := range envs { + if err := os.Setenv(k, v); err != nil { + t.Fatal(err) + } + } + + c := config.New(config.WithSource( + NewSource(prefix1, prefix2), + file.NewSource(path), + )) + + if err := c.Load(); err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + path string + expect interface{} + }{ + { + name: "test $KEY", + path: "test.server.name", + expect: "kratos_app", + }, + { + name: "test ${KEY:DEFAULT} without default", + path: "test.server.addr", + expect: "192.168.0.1", + }, + { + name: "test ${KEY:DEFAULT} with default", + path: "test.server.port", + expect: "8080", + }, + { + name: "test ${KEY} in array", + path: "foo", + expect: []interface{}{ + map[string]interface{}{ + "name": "Tom", + "age": "20", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var err error + v := c.Value(test.path) + if v.Load() != nil { + var actual interface{} + switch test.expect.(type) { + case int: + if actual, err = v.Int(); err == nil { + assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") + } + case string: + if actual, err = v.String(); err == nil { + assert.Equal(t, test.expect, actual, "string value should be equal") + } + case bool: + if actual, err = v.Bool(); err == nil { + assert.Equal(t, test.expect, actual, "bool value should be equal") + } + case float64: + if actual, err = v.Float(); err == nil { + assert.Equal(t, test.expect, actual, "float64 value should be equal") + } + default: + actual = v.Load() + if !reflect.DeepEqual(test.expect, actual) { + t.Logf("\nexpect: %#v\nactural: %#v", test.expect, actual) + t.Fail() + } + } + if err != nil { + t.Error(err) + } + } else { + t.Error("value path not found") + } + }) + } +} + +func TestEnvWithoutPrefix(t *testing.T) { + var ( + path = filepath.Join(os.TempDir(), "test_config") + filename = filepath.Join(path, "test.json") + data = []byte(_testJSON) + ) + defer os.Remove(path) + if err := os.MkdirAll(path, 0700); err != nil { + t.Error(err) + } + if err := ioutil.WriteFile(filename, data, 0666); err != nil { + t.Error(err) + } + + // set env + envs := map[string]string{ + "SERVICE_NAME": "kratos_app", + "ADDR": "192.168.0.1", + "AGE": "20", + } + + for k, v := range envs { + if err := os.Setenv(k, v); err != nil { + t.Fatal(err) + } + } + + c := config.New(config.WithSource( + NewSource(), + file.NewSource(path), + )) + + if err := c.Load(); err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + path string + expect interface{} + }{ + { + name: "test $KEY", + path: "test.server.name", + expect: "kratos_app", + }, + { + name: "test ${KEY:DEFAULT} without default", + path: "test.server.addr", + expect: "192.168.0.1", + }, + { + name: "test ${KEY:DEFAULT} with default", + path: "test.server.port", + expect: "8080", + }, + { + name: "test ${KEY} in array", + path: "foo", + expect: []interface{}{ + map[string]interface{}{ + "name": "Tom", + "age": "20", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var err error + v := c.Value(test.path) + if v.Load() != nil { + var actual interface{} + switch test.expect.(type) { + case int: + if actual, err = v.Int(); err == nil { + assert.Equal(t, test.expect, int(actual.(int64)), "int value should be equal") + } + case string: + if actual, err = v.String(); err == nil { + assert.Equal(t, test.expect, actual, "string value should be equal") + } + case bool: + if actual, err = v.Bool(); err == nil { + assert.Equal(t, test.expect, actual, "bool value should be equal") + } + case float64: + if actual, err = v.Float(); err == nil { + assert.Equal(t, test.expect, actual, "float64 value should be equal") + } + default: + actual = v.Load() + if !reflect.DeepEqual(test.expect, actual) { + t.Logf("\nexpect: %#v\nactural: %#v", test.expect, actual) + t.Fail() + } + } + if err != nil { + t.Error(err) + } + } else { + t.Error("value path not found") + } + }) + } +} diff --git a/config/env/watcher.go b/config/env/watcher.go new file mode 100644 index 000000000..15a3f8ce3 --- /dev/null +++ b/config/env/watcher.go @@ -0,0 +1,24 @@ +package env + +import ( + "github.com/go-kratos/kratos/v2/config" +) + +type watcher struct { + exit chan struct{} +} + +func NewWatcher() (config.Watcher, error) { + return &watcher{exit: make(chan struct{})}, nil +} + +// Next will be blocked until the Stop method is called +func (w *watcher) Next() ([]*config.KeyValue, error) { + <-w.exit + return nil, nil +} + +func (w *watcher) Stop() error { + close(w.exit) + return nil +} diff --git a/config/options.go b/config/options.go index 278698b34..de214c42b 100644 --- a/config/options.go +++ b/config/options.go @@ -92,8 +92,13 @@ func defaultResolver(input map[string]interface{}) error { } case []interface{}: for i, iface := range vt { - if s, ok := iface.(string); ok { - vt[i] = os.Expand(s, mapper) + switch it := iface.(type) { + case string: + vt[i] = os.Expand(it, mapper) + case map[string]interface{}: + if err := resolve(it); err != nil { + return err + } } } sub[k] = vt diff --git a/config/reader.go b/config/reader.go index 2dc778c3e..833508b40 100644 --- a/config/reader.go +++ b/config/reader.go @@ -1,6 +1,8 @@ package config import ( + "bytes" + "encoding/gob" "encoding/json" "fmt" "strings" @@ -59,15 +61,22 @@ func (r *reader) Source() ([]byte, error) { } func cloneMap(src map[string]interface{}) (map[string]interface{}, error) { - data, err := marshalJSON(src) + // https://gist.github.com/soroushjp/0ec92102641ddfc3ad5515ca76405f4d + var buf bytes.Buffer + gob.Register(map[string]interface{}{}) + gob.Register([]interface{}{}) + enc := gob.NewEncoder(&buf) + dec := gob.NewDecoder(&buf) + err := enc.Encode(src) if err != nil { return nil, err } - dst := make(map[string]interface{}) - if err = unmarshalJSON(data, &dst); err != nil { + var copy map[string]interface{} + err = dec.Decode(©) + if err != nil { return nil, err } - return dst, nil + return copy, nil } func convertMap(src interface{}) interface{} { diff --git a/config/value.go b/config/value.go index c087c3ed2..db122b09f 100644 --- a/config/value.go +++ b/config/value.go @@ -78,6 +78,8 @@ func (v *atomicValue) String() (string, error) { return val, nil case bool, int, int32, int64, float64: return fmt.Sprint(val), nil + case []byte: + return string(val), nil default: if s, ok := val.(fmt.Stringer); ok { return s.String(), nil