Skip to content

Commit c6708a3

Browse files
author
majianhan
committed
fix: support string duration values in JSON/YAML config files
time.Duration fields like watch-progress-notify-interval do not properly unmarshal from JSON strings (e.g. "1m"). This is because time.Duration is just an int64 and does not implement json.Unmarshaler. Add preprocessDurationFields() that converts string duration values to nanosecond integers before YAML/JSON unmarshaling. This applies to all duration config fields, matching the behavior of command-line flags which use time.ParseDuration. Fixes #20342 Signed-off-by: majianhan <majianhan@kylinos.cn>
1 parent 0f5cee1 commit c6708a3

File tree

2 files changed

+182
-0
lines changed

2 files changed

+182
-0
lines changed

server/embed/config.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package embed
1616

1717
import (
1818
"crypto/tls"
19+
"encoding/json"
1920
"errors"
2021
"flag"
2122
"fmt"
@@ -771,12 +772,74 @@ func ConfigFromFile(path string) (*Config, error) {
771772
return &cfg.Config, nil
772773
}
773774

775+
// durationFieldKeys lists all JSON keys in Config that correspond to time.Duration fields.
776+
// String values (e.g. "1m", "500ms") are converted to nanosecond integers before unmarshaling,
777+
// because time.Duration does not implement json.Unmarshaler.
778+
var durationFieldKeys = map[string]bool{
779+
"backend-batch-interval": true,
780+
"grpc-keepalive-min-time": true,
781+
"grpc-keepalive-interval": true,
782+
"grpc-keepalive-timeout": true,
783+
"corrupt-check-time": true,
784+
"compact-hash-check-time": true,
785+
"compaction-sleep-interval": true,
786+
"watch-progress-notify-interval": true,
787+
"warning-apply-duration": true,
788+
"warning-unary-request-duration": true,
789+
"downgrade-check-time": true,
790+
}
791+
792+
// preprocessDurationFields converts string duration values (e.g. "1m", "500ms") to
793+
// nanosecond integers so that time.Duration fields unmarshal correctly from JSON/YAML.
794+
func preprocessDurationFields(b []byte) ([]byte, error) {
795+
var raw map[string]json.RawMessage
796+
if err := yaml.Unmarshal(b, &raw); err != nil {
797+
// If parsing as a map fails, return the original bytes and let the
798+
// caller handle the error during normal unmarshaling.
799+
return b, nil
800+
}
801+
802+
modified := false
803+
for key, val := range raw {
804+
if !durationFieldKeys[key] {
805+
continue
806+
}
807+
// Try to unmarshal as a string (e.g. "1m", "10s").
808+
var s string
809+
if err := json.Unmarshal(val, &s); err != nil {
810+
// Not a string; might already be a number, which is fine.
811+
continue
812+
}
813+
d, err := time.ParseDuration(s)
814+
if err != nil {
815+
return nil, fmt.Errorf("invalid duration value for %q: %w", key, err)
816+
}
817+
nsBytes, err := json.Marshal(d.Nanoseconds())
818+
if err != nil {
819+
return nil, err
820+
}
821+
raw[key] = nsBytes
822+
modified = true
823+
}
824+
825+
if !modified {
826+
return b, nil
827+
}
828+
829+
return yaml.Marshal(raw)
830+
}
831+
774832
func (cfg *configYAML) configFromFile(path string) error {
775833
b, err := os.ReadFile(path)
776834
if err != nil {
777835
return err
778836
}
779837

838+
b, err = preprocessDurationFields(b)
839+
if err != nil {
840+
return err
841+
}
842+
780843
defaultInitialCluster := cfg.InitialCluster
781844

782845
err = yaml.Unmarshal(b, cfg)

server/embed/config_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package embed
1616

1717
import (
1818
"crypto/tls"
19+
"encoding/json"
1920
"errors"
2021
"flag"
2122
"fmt"
@@ -944,3 +945,121 @@ func TestFastLeaseKeepAliveValidate(t *testing.T) {
944945
})
945946
}
946947
}
948+
949+
func TestConfigFileDurationFields(t *testing.T) {
950+
testCases := []struct {
951+
name string
952+
config map[string]any
953+
expectErr bool
954+
check func(t *testing.T, cfg *Config)
955+
}{
956+
{
957+
name: "string duration for watch-progress-notify-interval",
958+
config: map[string]any{
959+
"watch-progress-notify-interval": "1m",
960+
},
961+
check: func(t *testing.T, cfg *Config) {
962+
require.Equal(t, time.Minute, cfg.WatchProgressNotifyInterval)
963+
},
964+
},
965+
{
966+
name: "numeric duration (nanoseconds) for watch-progress-notify-interval",
967+
config: map[string]any{
968+
"watch-progress-notify-interval": float64(time.Minute),
969+
},
970+
check: func(t *testing.T, cfg *Config) {
971+
require.Equal(t, time.Minute, cfg.WatchProgressNotifyInterval)
972+
},
973+
},
974+
{
975+
name: "string durations for multiple fields",
976+
config: map[string]any{
977+
"watch-progress-notify-interval": "30s",
978+
"backend-batch-interval": "500ms",
979+
"grpc-keepalive-min-time": "5s",
980+
"grpc-keepalive-interval": "2h",
981+
"grpc-keepalive-timeout": "20s",
982+
"corrupt-check-time": "4m",
983+
"compact-hash-check-time": "2m",
984+
"compaction-sleep-interval": "100ms",
985+
"warning-apply-duration": "200ms",
986+
"warning-unary-request-duration": "300ms",
987+
"downgrade-check-time": "5s",
988+
},
989+
check: func(t *testing.T, cfg *Config) {
990+
require.Equal(t, 30*time.Second, cfg.WatchProgressNotifyInterval)
991+
require.Equal(t, 500*time.Millisecond, cfg.BackendBatchInterval)
992+
require.Equal(t, 5*time.Second, cfg.GRPCKeepAliveMinTime)
993+
require.Equal(t, 2*time.Hour, cfg.GRPCKeepAliveInterval)
994+
require.Equal(t, 20*time.Second, cfg.GRPCKeepAliveTimeout)
995+
require.Equal(t, 4*time.Minute, cfg.CorruptCheckTime)
996+
require.Equal(t, 2*time.Minute, cfg.CompactHashCheckTime)
997+
require.Equal(t, 100*time.Millisecond, cfg.CompactionSleepInterval)
998+
require.Equal(t, 200*time.Millisecond, cfg.WarningApplyDuration)
999+
require.Equal(t, 300*time.Millisecond, cfg.WarningUnaryRequestDuration)
1000+
require.Equal(t, 5*time.Second, cfg.DowngradeCheckTime)
1001+
},
1002+
},
1003+
{
1004+
name: "invalid duration string",
1005+
config: map[string]any{
1006+
"watch-progress-notify-interval": "not-a-duration",
1007+
},
1008+
expectErr: true,
1009+
},
1010+
}
1011+
for _, tc := range testCases {
1012+
t.Run(tc.name, func(t *testing.T) {
1013+
b, err := json.Marshal(tc.config)
1014+
require.NoError(t, err)
1015+
1016+
tmpfile := mustCreateCfgFile(t, b)
1017+
defer os.Remove(tmpfile.Name())
1018+
1019+
cfg, err := ConfigFromFile(tmpfile.Name())
1020+
if tc.expectErr {
1021+
require.Error(t, err)
1022+
return
1023+
}
1024+
require.NoError(t, err)
1025+
tc.check(t, cfg)
1026+
})
1027+
}
1028+
}
1029+
1030+
func TestPreprocessDurationFields(t *testing.T) {
1031+
testCases := []struct {
1032+
name string
1033+
input string
1034+
expectErr bool
1035+
}{
1036+
{
1037+
name: "string duration value",
1038+
input: `{"watch-progress-notify-interval": "1m"}`,
1039+
},
1040+
{
1041+
name: "numeric duration value passes through",
1042+
input: `{"watch-progress-notify-interval": 60000000000}`,
1043+
},
1044+
{
1045+
name: "non-duration field unchanged",
1046+
input: `{"name": "my-etcd"}`,
1047+
},
1048+
{
1049+
name: "invalid duration string",
1050+
input: `{"watch-progress-notify-interval": "invalid"}`,
1051+
expectErr: true,
1052+
},
1053+
}
1054+
for _, tc := range testCases {
1055+
t.Run(tc.name, func(t *testing.T) {
1056+
result, err := preprocessDurationFields([]byte(tc.input))
1057+
if tc.expectErr {
1058+
require.Error(t, err)
1059+
return
1060+
}
1061+
require.NoError(t, err)
1062+
require.NotEmpty(t, result)
1063+
})
1064+
}
1065+
}

0 commit comments

Comments
 (0)