fix(config): replace text with "${}" only (#1375)

* fix(config): replace text with "${}" only
pull/1377/head
包子 3 years ago committed by GitHub
parent eaf0ceab0c
commit 925e55a04d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      config/config_test.go
  2. 2
      config/env/env_test.go
  3. 22
      config/options.go
  4. 10
      internal/httputil/http_test.go

@ -116,7 +116,6 @@ func TestConfig(t *testing.T) {
httpAddr = "0.0.0.0" httpAddr = "0.0.0.0"
httpTimeout = 0.5 httpTimeout = 0.5
grpcPort = 10080 grpcPort = 10080
enableSSL = true
endpoint1 = "www.aaa.com" endpoint1 = "www.aaa.com"
databaseDriver = "mysql" databaseDriver = "mysql"
) )
@ -160,7 +159,7 @@ func TestConfig(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, httpAddr, testConf.Server.Http.Addr) assert.Equal(t, httpAddr, testConf.Server.Http.Addr)
assert.Equal(t, httpTimeout, testConf.Server.Http.Timeout) assert.Equal(t, httpTimeout, testConf.Server.Http.Timeout)
assert.Equal(t, enableSSL, testConf.Server.Http.EnableSSL) assert.Equal(t, true, testConf.Server.Http.EnableSSL)
assert.Equal(t, grpcPort, testConf.Server.GRpc.Port) assert.Equal(t, grpcPort, testConf.Server.GRpc.Port)
assert.Equal(t, endpoint1, testConf.Endpoints[0]) assert.Equal(t, endpoint1, testConf.Endpoints[0])
assert.Equal(t, 2, len(testConf.Endpoints)) assert.Equal(t, 2, len(testConf.Endpoints))
@ -170,7 +169,6 @@ func TestDefaultResolver(t *testing.T) {
var ( var (
portString = "8080" portString = "8080"
countInt = 10 countInt = 10
enableBool = true
rateFloat = 0.9 rateFloat = 0.9
) )
@ -226,7 +224,7 @@ func TestDefaultResolver(t *testing.T) {
{ {
name: "test bool with default", name: "test bool with default",
path: "foo.bar.enable", path: "foo.bar.enable",
expect: enableBool, expect: true,
}, },
{ {
name: "test float without default", name: "test float without default",
@ -253,16 +251,6 @@ func TestDefaultResolver(t *testing.T) {
path: "foo.bar.value1", path: "foo.bar.value1",
expect: "foobar", expect: "foobar",
}, },
{
name: "test $value",
path: "foo.bar.value2",
expect: portString,
},
{
name: "test $value:default",
path: "foo.bar.value3",
expect: portString + ":default",
},
} }
for _, test := range tests { for _, test := range tests {

@ -16,7 +16,7 @@ const _testJSON = `
{ {
"test":{ "test":{
"server":{ "server":{
"name":"$SERVICE_NAME", "name":"${SERVICE_NAME}",
"addr":"${ADDR:127.0.0.1}", "addr":"${ADDR:127.0.0.1}",
"port":"${PORT:8080}" "port":"${PORT:8080}"
} }

@ -2,7 +2,7 @@ package config
import ( import (
"fmt" "fmt"
"os" "regexp"
"strings" "strings"
"github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding"
@ -81,7 +81,7 @@ func defaultDecoder(src *KeyValue, target map[string]interface{}) error {
} }
// defaultResolver resolve placeholder in map value, // defaultResolver resolve placeholder in map value,
// placeholder format in ${key:default} or $key. // placeholder format in ${key:default}.
func defaultResolver(input map[string]interface{}) error { func defaultResolver(input map[string]interface{}) error {
mapper := func(name string) string { mapper := func(name string) string {
args := strings.SplitN(strings.TrimSpace(name), ":", 2) args := strings.SplitN(strings.TrimSpace(name), ":", 2)
@ -99,7 +99,7 @@ func defaultResolver(input map[string]interface{}) error {
for k, v := range sub { for k, v := range sub {
switch vt := v.(type) { switch vt := v.(type) {
case string: case string:
sub[k] = os.Expand(vt, mapper) sub[k] = expand(vt, mapper)
case map[string]interface{}: case map[string]interface{}:
if err := resolve(vt); err != nil { if err := resolve(vt); err != nil {
return err return err
@ -108,7 +108,7 @@ func defaultResolver(input map[string]interface{}) error {
for i, iface := range vt { for i, iface := range vt {
switch it := iface.(type) { switch it := iface.(type) {
case string: case string:
vt[i] = os.Expand(it, mapper) vt[i] = expand(it, mapper)
case map[string]interface{}: case map[string]interface{}:
if err := resolve(it); err != nil { if err := resolve(it); err != nil {
return err return err
@ -122,3 +122,17 @@ func defaultResolver(input map[string]interface{}) error {
} }
return resolve(input) return resolve(input)
} }
func expand(s string, mapping func(string) string) string {
r, err := regexp.Compile(`\${(.*?)}`)
if err != nil {
return s
}
re := r.FindAllStringSubmatch(s, -1)
for _, i := range re {
if len(i) == 2 {
s = strings.ReplaceAll(s, i[0], mapping(i[1]))
}
}
return s
}

@ -97,12 +97,12 @@ func TestStatusFromGRPCCode(t *testing.T) {
func TestContentType(t *testing.T) { func TestContentType(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
subtype string subtype string
want string want string
}{ }{
{"kratos","kratos","application/kratos"}, {"kratos", "kratos", "application/kratos"},
{"json","json","application/json"}, {"json", "json", "application/json"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -111,4 +111,4 @@ func TestContentType(t *testing.T) {
} }
}) })
} }
} }

Loading…
Cancel
Save