Skip to content

Commit

Permalink
fix: only load config and rules once (#1470)
Browse files Browse the repository at this point in the history
## Which problem is this PR solving?
 
We found an issue where refinery, as part of its config load process on
startup or reload, would load the config/rules (whether from file or
http or any), multiple times. This was happening because we the config
data was passed around as a reader and if we ever needed the data a
second time (which we do for validation vs unmarshaling into config) we
had to go get another reader.

## Short description of the changes

- updates `newFileConfig` to first grab the data of all config locations
and the use this data for all validation and unmarshalling. This
prevents `newFileConfig` from needing to grab the data multiple times
- updated some unit tests, but most needed unchanged.

---------

Co-authored-by: Yingrong Zhao <22300958+VinozzZ@users.noreply.github.com>
  • Loading branch information
TylerHelmuth and VinozzZ authored Jan 23, 2025
1 parent 70ca257 commit 570b6c2
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 74 deletions.
132 changes: 70 additions & 62 deletions config/configLoadHelpers.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -57,8 +56,8 @@ func formatFromResponse(resp *http.Response) Format {
}
}

// getReaderFor returns an io.ReadCloser for the given URL or filename.
func getReaderFor(u string) (io.ReadCloser, Format, error) {
// getBytesFor returns an []byte for the given URL or filename.
func getBytesFor(u string) ([]byte, Format, error) {
if u == "" {
return nil, FormatUnknown, fmt.Errorf("empty url")
}
Expand All @@ -68,7 +67,7 @@ func getReaderFor(u string) (io.ReadCloser, Format, error) {
}
switch uu.Scheme {
case "file", "": // we treat an empty scheme as a filename
r, err := os.Open(uu.Path)
r, err := os.ReadFile(uu.Path)
if err != nil {
return nil, FormatUnknown, err
}
Expand Down Expand Up @@ -102,76 +101,87 @@ func getReaderFor(u string) (io.ReadCloser, Format, error) {
if format == FormatUnknown {
format = formatFromFilename(uu.Path)
}
return resp.Body, format, nil
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, FormatUnknown, err
}
return body, format, nil
default:
return nil, FormatUnknown, fmt.Errorf("unknown scheme %q", uu.Scheme)
}
}

func load(r io.Reader, format Format, into any) error {
func load(data []byte, format Format, into any) error {
switch format {
case FormatYAML:
decoder := yaml.NewDecoder(r)
err := decoder.Decode(into)
err := yaml.Unmarshal(data, into)
return err
case FormatTOML:
decoder := toml.NewDecoder(r)
err := decoder.Decode(into)
err := toml.Unmarshal(data, into)
return err
case FormatJSON:
decoder := json.NewDecoder(r)
err := decoder.Decode(into)
err := json.Unmarshal(data, into)
return err
default:
return fmt.Errorf("unable to determine data format")
}
}

type configData struct {
data []byte
format Format
location string
}

// getConfigDataForLocations returns a slice of configData grabbed from each location.
func getConfigDataForLocations(locations []string) ([]configData, error) {
results := make([]configData, len(locations))
for i, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
data, format, err := getBytesFor(location)
if err != nil {
return nil, err
}
results[i] = configData{
data: data,
format: format,
location: location,
}
}
return results, nil
}

// This loads all the named configs into destination in the order they are listed.
// It returns the MD5 hash of the collected configs as a string (if there's only one
// config, this is the hash of that config; if there are multiple, it's the hash of
// all of them concatenated together).
func loadConfigsInto(dest any, locations []string) (string, error) {
// start a hash of the configs we read
func loadConfigsInto(dest any, configs []configData) (string, error) {
// start a hash of the configs we process
h := md5.New()
for _, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
r, format, err := getReaderFor(location)
if err != nil {
return "", err
}
defer r.Close()
// write the data to the hash as we read it
rdr := io.TeeReader(r, h)
for _, c := range configs {
// write the data to the hash
h.Write(c.data)

// when working on a struct, load only overwrites destination values that are
// explicitly named. So we can just keep loading successive files into
// explicitly named. So we can just keep loading successive sources into
// the same object without losing data we've already specified.
if err := load(rdr, format, dest); err != nil {
return "", fmt.Errorf("loadConfigsInto unable to load config %s: %w", location, err)
if err := load(c.data, c.format, dest); err != nil {
return "", fmt.Errorf("loadConfigsInto unable to load config %s: %w", c.location, err)
}
}
hash := hex.EncodeToString(h.Sum(nil))
return hash, nil
}

func loadConfigsIntoMap(dest map[string]any, locations []string) error {
for _, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
r, format, err := getReaderFor(location)
if err != nil {
return err
}
defer r.Close()

func loadConfigsIntoMap(dest map[string]any, configs []configData) error {
for _, c := range configs {
// when working on a map, when loading a nested object, load will overwrite the entire destination
// value, so we can't just keep loading successive files into the same object. Instead, we
// need to load into a new object and then merge it into the map.
temp := make(map[string]any)
if err := load(r, format, &temp); err != nil {
return fmt.Errorf("loadConfigsInto unable to load config %s: %w", location, err)
if err := load(c.data, c.format, &temp); err != nil {
return fmt.Errorf("loadConfigsInto unable to load config %s: %w", c.location, err)
}
for k, v := range temp {
switch vm := v.(type) {
Expand All @@ -196,13 +206,13 @@ func loadConfigsIntoMap(dest map[string]any, locations []string) error {
return nil
}

// validateConfigs reads the configs from the given location and validates them.
// validateConfigs gets the configs from the given location and validates them.
// It returns a list of failures; if the list is empty, the config is valid.
// err is non-nil only for significant errors like a missing file.
func validateConfigs(opts *CmdEnv) ([]string, error) {
// first read the configs into a map so we can validate them
func validateConfigs(configs []configData, opts *CmdEnv) ([]string, error) {
// first process the configs into a map so we can validate them
userData := make(map[string]any)
err := loadConfigsIntoMap(userData, opts.ConfigLocations)
err := loadConfigsIntoMap(userData, configs)
if err != nil {
return nil, err
}
Expand All @@ -220,19 +230,19 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
// Basic validation worked. Now we need to reload everything into our struct so that
// we can apply defaults and options, and then validate a second time.
var config configContents
_, err = loadConfigsInto(&config, opts.ConfigLocations)
_, err = loadConfigsInto(&config, configs)
if err != nil {
return nil, err
}

// apply defaults and options
if err := defaults.Set(&config); err != nil {
return nil, fmt.Errorf("readConfigInto unable to apply defaults: %w", err)
return nil, fmt.Errorf("loadConfigsInto unable to apply defaults: %w", err)
}

// apply command line options
if err := opts.ApplyTags(reflect.ValueOf(&config)); err != nil {
return nil, fmt.Errorf("readConfigInto unable to apply command line options: %w", err)
return nil, fmt.Errorf("loadConfigsInto unable to apply command line options: %w", err)
}

// possibly inject some keys to keep the validator happy
Expand All @@ -249,17 +259,15 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
config.OTelTracing.APIKey = "InvalidHoneycombAPIKey"
}

// The validator needs a map[string]any to work with, so we need to
// write it out to a buffer (we always use YAML) and then reload it.
buf := new(bytes.Buffer)
encoder := yaml.NewEncoder(buf)
encoder.SetIndent(2)
if err := encoder.Encode(config); err != nil {
return nil, fmt.Errorf("readConfigInto unable to reencode config: %w", err)
// The validator needs a map[string]any to work with, so we marshal to
// yaml bytes for an easy conversion to map[string]any.
data, err := yaml.Marshal(config)
if err != nil {
return nil, fmt.Errorf("loadConfigsInto unable to remarshal config: %w", err)
}

var rewrittenUserData map[string]any
if err := load(buf, FormatYAML, &rewrittenUserData); err != nil {
if err := load(data, FormatYAML, &rewrittenUserData); err != nil {
return nil, fmt.Errorf("validateConfig unable to reload hydrated config from buffer: %w", err)
}

Expand All @@ -268,10 +276,10 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
return failures, nil
}

func validateRules(locations []string) ([]string, error) {
// first read the configs into a map so we can validate them
func validateRules(configs []configData) ([]string, error) {
// first process the configs into a map so we can validate them
userData := make(map[string]any)
err := loadConfigsIntoMap(userData, locations)
err := loadConfigsIntoMap(userData, configs)
if err != nil {
return nil, err
}
Expand All @@ -285,9 +293,9 @@ func validateRules(locations []string) ([]string, error) {
return failures, nil
}

// readConfigInto reads the config from the given location and applies it to the given struct.
func readConfigInto(dest any, locations []string, opts *CmdEnv) (string, error) {
hash, err := loadConfigsInto(dest, locations)
// applyConfigInto applies the given configs to the given struct.
func applyConfigInto(dest any, configs []configData, opts *CmdEnv) (string, error) {
hash, err := loadConfigsInto(dest, configs)
if err != nil {
return hash, err
}
Expand All @@ -299,12 +307,12 @@ func readConfigInto(dest any, locations []string, opts *CmdEnv) (string, error)

// now we've got the config, apply defaults to zero values
if err := defaults.Set(dest); err != nil {
return hash, fmt.Errorf("readConfigInto unable to apply defaults: %w", err)
return hash, fmt.Errorf("applyConfigInto unable to apply defaults: %w", err)
}

// apply command line options
if err := opts.ApplyTags(reflect.ValueOf(dest)); err != nil {
return hash, fmt.Errorf("readConfigInto unable to apply command line options: %w", err)
return hash, fmt.Errorf("applyConfigInto unable to apply command line options: %w", err)
}

return hash, nil
Expand Down
17 changes: 11 additions & 6 deletions config/configLoadHelpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func Test_loadDuration(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := load(strings.NewReader(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
if err := load([]byte(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
t.Errorf("load() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.into, tt.want) {
Expand Down Expand Up @@ -123,7 +123,7 @@ func Test_loadMemsize(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := load(strings.NewReader(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
if err := load([]byte(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
t.Errorf("load() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.into, tt.want) {
Expand Down Expand Up @@ -198,9 +198,10 @@ func Test_loadConfigsInto(t *testing.T) {
cm1 := makeYAML("General.ConfigurationVersion", 2, "General.ConfigReloadInterval", Duration(1*time.Second), "Network.ListenAddr", "0.0.0.0:8080")
cm2 := makeYAML("General.ConfigReloadInterval", Duration(2*time.Second), "General.DatasetPrefix", "hello")
cfgfiles := createTempConfigs(t, cm1, cm2)

configs, err := getConfigDataForLocations(cfgfiles)
require.NoError(t, err)
cfg := configContents{}
hash, err := loadConfigsInto(&cfg, cfgfiles)
hash, err := loadConfigsInto(&cfg, configs)
require.NoError(t, err)
require.Equal(t, "2381a6563085f50ac56663b67ca85299", hash)
require.Equal(t, 2, cfg.General.ConfigurationVersion)
Expand All @@ -213,9 +214,11 @@ func Test_loadConfigsIntoMap(t *testing.T) {
cm1 := makeYAML("General.ConfigurationVersion", 2, "General.ConfigReloadInterval", Duration(1*time.Second), "Network.ListenAddr", "0.0.0.0:8080")
cm2 := makeYAML("General.ConfigReloadInterval", Duration(2*time.Second), "General.DatasetPrefix", "hello")
cfgfiles := createTempConfigs(t, cm1, cm2)
configs, err := getConfigDataForLocations(cfgfiles)
require.NoError(t, err)

cfg := map[string]any{}
err := loadConfigsIntoMap(cfg, cfgfiles)
err = loadConfigsIntoMap(cfg, configs)
require.NoError(t, err)
gen := cfg["General"].(map[string]any)
require.Equal(t, 2, gen["ConfigurationVersion"])
Expand Down Expand Up @@ -262,7 +265,9 @@ func Test_validateConfigs(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
cfgfiles := createTempConfigs(t, tt.cfgs...)
opts := &CmdEnv{ConfigLocations: cfgfiles}
got, err := validateConfigs(opts)
configs, err := getConfigDataForLocations(cfgfiles)
require.NoError(t, err)
got, err := validateConfigs(configs, opts)
if (err != nil) != tt.wantErr {
t.Errorf("validateConfigs() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
21 changes: 15 additions & 6 deletions config/file_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,17 +450,26 @@ func (e *FileConfigError) Error() string {
// newFileConfig does the work of creating and loading the start of a config object
// from the given arguments.
// It's used by both the main init as well as the reload code.
// In order to do proper validation, we actually read the file twice -- once into
// a map, and once into the actual config object.
// In order to do proper validation, we actually process the data twice -- once as
// a map, and once as the actual config object.
func newFileConfig(opts *CmdEnv) (*fileConfig, error) {
configData, err := getConfigDataForLocations(opts.ConfigLocations)
if err != nil {
return nil, err
}
rulesData, err := getConfigDataForLocations(opts.RulesLocations)
if err != nil {
return nil, err
}

// If we're not validating, skip this part
if !opts.NoValidate {
cfgFails, err := validateConfigs(opts)
cfgFails, err := validateConfigs(configData, opts)
if err != nil {
return nil, err
}

ruleFails, err := validateRules(opts.RulesLocations)
ruleFails, err := validateRules(rulesData)
if err != nil {
return nil, err
}
Expand All @@ -477,13 +486,13 @@ func newFileConfig(opts *CmdEnv) (*fileConfig, error) {

// Now load the files
mainconf := &configContents{}
mainhash, err := readConfigInto(mainconf, opts.ConfigLocations, opts)
mainhash, err := applyConfigInto(mainconf, configData, opts)
if err != nil {
return nil, err
}

var rulesconf *V2SamplerConfig
ruleshash, err := readConfigInto(&rulesconf, opts.RulesLocations, nil)
ruleshash, err := applyConfigInto(&rulesconf, rulesData, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 570b6c2

Please sign in to comment.