Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor configuration management to use a .conf file #1292

Merged
merged 16 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
544 changes: 206 additions & 338 deletions config/config.go

Large diffs are not rendered by default.

243 changes: 243 additions & 0 deletions config/parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
package config

import (
"bufio"
"fmt"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
"time"
)

// ConfigParser handles the parsing of configuration files
type ConfigParser struct {
// store holds the raw key-value pairs from the config file
store map[string]string
}

// NewConfigParser creates a new instance of ConfigParser
func NewConfigParser() *ConfigParser {
return &ConfigParser{
store: make(map[string]string),
}
}

// ParseFromFile reads the configuration data from a file
func (p *ConfigParser) ParseFromFile(filename string) error {
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("error opening config file: %w", err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
return processConfigData(scanner, p)
}

// ParseFromStdin reads the configuration data from stdin
func (p *ConfigParser) ParseFromStdin() error {
scanner := bufio.NewScanner(os.Stdin)
return processConfigData(scanner, p)
}

// ParseDefaults populates a struct with default values based on struct tag `default`
func (p *ConfigParser) ParseDefaults(cfg interface{}) error {
val := reflect.ValueOf(cfg)
if val.Kind() != reflect.Ptr || val.IsNil() {
return fmt.Errorf("config must be a non-nil pointer to a struct")
}

val = val.Elem()
if val.Kind() != reflect.Struct {
return fmt.Errorf("config must be a pointer to a struct")
}

return p.unmarshalStruct(val, "")
}

// Loadconfig populates a struct with configuration values based on struct tags
func (p *ConfigParser) Loadconfig(cfg interface{}) error {
val := reflect.ValueOf(cfg)
if val.Kind() != reflect.Ptr || val.IsNil() {
return fmt.Errorf("config must be a non-nil pointer to a struct")
}

val = val.Elem()
if val.Kind() != reflect.Struct {
return fmt.Errorf("config must be a pointer to a struct")
}

if err := p.unmarshalStruct(val, ""); err != nil {
return fmt.Errorf("failed to unmarshal config: %w", err)
}

if err := validateConfig(DiceConfig); err != nil {
return fmt.Errorf("failed to validate config: %w", err)
}

return nil
}

// processConfigData reads the configuration data line by line and stores it in the ConfigParser
func processConfigData(scanner *bufio.Scanner, p *ConfigParser) error {
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}

parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
slog.Warn("invalid config line", slog.String("line", line))
continue
}

key := strings.TrimSpace(parts[0])
value := strings.Trim(strings.TrimSpace(parts[1]), "\"")
p.store[key] = value
}

return scanner.Err()
}

// unmarshalStruct handles the recursive struct parsing.
func (p *ConfigParser) unmarshalStruct(val reflect.Value, prefix string) error {
typ := val.Type()

for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)

// Skip unexported fields just like how encoding/json does
if !field.CanSet() {
continue
}

// Get config key or field name
key := fieldType.Tag.Get("config")

// Use field name as key if not specified in tag
if key == "" {
key = strings.ToLower(fieldType.Name)
}

// Skip fields with "-" tag
if key == "-" {
continue
}

// Apply nested struct's tag as prefix
fullKey := key
if prefix != "" {
fullKey = fmt.Sprintf("%s.%s", prefix, key)
}

// Recursively process nested structs with their prefix
if field.Kind() == reflect.Struct {
if err := p.unmarshalStruct(field, fullKey); err != nil {
return err
}
continue
}

// Fetch and set value for non-struct fields
value, exists := p.store[fullKey]
if !exists {
// Use default value from tag if available
if defaultValue := fieldType.Tag.Get("default"); defaultValue != "" {
value = defaultValue
} else {
continue
}
}

if err := setField(field, value); err != nil {
return fmt.Errorf("error setting field %s: %w", fullKey, err)
}
}

return nil
}

// setField sets the appropriate field value based on its type
func setField(field reflect.Value, value string) error {
switch field.Kind() {
case reflect.String:
field.SetString(value)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if field.Type() == reflect.TypeOf(time.Duration(0)) {
// Handle time.Duration type
duration, err := parseDuration(value)
if err != nil {
return fmt.Errorf("failed to parse duration: %w", err)
}
field.Set(reflect.ValueOf(duration))
} else {
// Handle other integer types
val, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse integer: %w", err)
}
field.SetInt(val)
}

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
val, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse unsigned integer: %w", err)
}
field.SetUint(val)

case reflect.Float32, reflect.Float64:
val, err := strconv.ParseFloat(value, 64)
if err != nil {
return fmt.Errorf("failed to parse float: %w", err)
}
field.SetFloat(val)

case reflect.Bool:
val, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("failed to parse boolean: %w", err)
}
field.SetBool(val)

case reflect.Slice:
// Handle slices of basic types
elemType := field.Type().Elem()
values := strings.Split(value, ",")
slice := reflect.MakeSlice(field.Type(), len(values), len(values))
for i, v := range values {
elem := slice.Index(i)
elemVal := reflect.New(elemType).Elem()
if err := setField(elemVal, strings.TrimSpace(v)); err != nil {
return fmt.Errorf("failed to parse slice element at index %d: %w", i, err)
}
elem.Set(elemVal)
}
field.Set(slice)

default:
return fmt.Errorf("unsupported type: %s", field.Type())
}

return nil
}

// parseDuration handles parsing of time.Duration with proper validation.
func parseDuration(value string) (time.Duration, error) {
if value == "" {
return 0, fmt.Errorf("duration string is empty")
}
duration, err := time.ParseDuration(value)
if err != nil {
return 0, fmt.Errorf("invalid duration format: %s", value)
}
if duration <= 0 {
return 0, fmt.Errorf("duration must be positive, got: %s", value)
}
return duration, nil
}
97 changes: 97 additions & 0 deletions config/validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package config

import (
"fmt"
"log"
"reflect"
"strings"

"github.com/go-playground/validator/v10"
)

func validateConfig(config *Config) error {
validate := validator.New()
validate.RegisterStructValidation(validateShardCount, Config{})

if err := validate.Struct(config); err != nil {
validationErrors, ok := err.(validator.ValidationErrors)
if !ok {
return fmt.Errorf("unexpected validation error type: %v", err)
}

processedFields := make(map[string]bool)

for _, validationErr := range validationErrors {
fieldName := strings.TrimPrefix(validationErr.Namespace(), "Config.")

if processedFields[fieldName] {
continue
}
processedFields[fieldName] = true

log.Printf("Field %s failed validation: %s", fieldName, validationErr.Tag())

if err := applyDefaultValuesFromTags(config, fieldName); err != nil {
return fmt.Errorf("error setting default for %s: %v", fieldName, err)
}
}
}
return nil
}

func validateShardCount(sl validator.StructLevel) {
config := sl.Current().Interface().(Config)
if config.Performance.NumShards <= 0 && config.Performance.NumShards != -1 {
sl.ReportError(config.Performance.NumShards, "NumShards", "NumShards", "invalidValue", "must be -1 or greater than 0")
}
}

func applyDefaultValuesFromTags(config *Config, fieldName string) error {
configType := reflect.TypeOf(config).Elem()
configValue := reflect.ValueOf(config).Elem()

// Split the field name if it refers to a nested struct
parts := strings.Split(fieldName, ".")
var field reflect.StructField
var fieldValue reflect.Value
var found bool

// Traverse the struct to find the nested field
for i, part := range parts {
// If it's the first field, just look in the top-level struct
if i == 0 {
field, found = configType.FieldByName(part)
if !found {
log.Printf("Warning: %s field not found", part)
return fmt.Errorf("field %s not found in config struct", part)
}
fieldValue = configValue.FieldByName(part)
} else {
// Otherwise, the struct is nested, so navigate into it
if fieldValue.Kind() == reflect.Struct {
field, found = fieldValue.Type().FieldByName(part)
if !found {
log.Printf("Warning: %s field not found in %s", part, fieldValue.Type())
return fmt.Errorf("field %s not found in struct %s", part, fieldValue.Type())
}
fieldValue = fieldValue.FieldByName(part)
} else {
log.Printf("Warning: %s is not a struct", fieldName)
return fmt.Errorf("%s is not a struct", fieldName)
}
}
}

defaultValue := field.Tag.Get("default")
if defaultValue == "" {
log.Printf("Warning: %s field has no default value to set, leaving empty string", fieldName)
return nil
}

if err := setField(fieldValue, defaultValue); err != nil {
return fmt.Errorf("error setting default value for %s: %v", fieldName, err)
}

log.Printf("Setting default value for %s to: %s", fieldName, defaultValue)
return nil
}
53 changes: 0 additions & 53 deletions dice.toml

This file was deleted.

Loading
Loading