From 60bd0dc5b21a6d1d3eec0f660da3917aa3b2d8c8 Mon Sep 17 00:00:00 2001 From: Tanz0rz Date: Thu, 1 Feb 2024 16:08:36 -0700 Subject: [PATCH] Improving parameter validation logic and improving project code check tests (#11) --- .../{parameters.go => parameter_types.go} | 101 ++++++------------ .../client/validate/parameter_types_test.go | 62 +++++++++++ golang/project-validation/directories.go | 27 +++++ .../integration_check_cleanup_test.go | 65 +++++++++++ .../integration_test_validation/main.go | 84 --------------- ...parameter_validation_check_format_test.go} | 65 ++++++----- 6 files changed, 226 insertions(+), 178 deletions(-) rename golang/client/validate/{parameters.go => parameter_types.go} (63%) create mode 100644 golang/client/validate/parameter_types_test.go create mode 100644 golang/project-validation/directories.go create mode 100644 golang/project-validation/integration_check_cleanup_test.go delete mode 100644 golang/project-validation/integration_test_validation/main.go rename golang/project-validation/{check_parameter_validation_function_calls/main.go => parameter_validation_check_format_test.go} (55%) diff --git a/golang/client/validate/parameters.go b/golang/client/validate/parameter_types.go similarity index 63% rename from golang/client/validate/parameters.go rename to golang/client/validate/parameter_types.go index 760f8704..28c42a43 100644 --- a/golang/client/validate/parameters.go +++ b/golang/client/validate/parameter_types.go @@ -7,6 +7,7 @@ import ( type ParameterType int +// Generic parameter types const ( EthereumAddress ParameterType = iota EthereumAddressPointer @@ -14,9 +15,17 @@ const ( BigIntPointer ChainID PrivateKey - ApprovalType +) + +// Swap parameter types +const ( + ApprovalType ParameterType = iota + 50 Slippage - Page +) + +// Orderbook parameter types +const ( + Page ParameterType = iota + 100 PagePointer Limit LimitPointer @@ -29,168 +38,128 @@ const ( OrderHash ) -func GetParameterTypeName(parameterType ParameterType) string { - switch parameterType { - case EthereumAddress: - return "EthereumAddress" - case EthereumAddressPointer: - return "EthereumAddressPointer" - case BigInt: - return "BigInt" - case BigIntPointer: - return "BigIntPointer" - case ChainID: - return "ChainID" - case PrivateKey: - return "PrivateKey" - case ApprovalType: - return "ApprovalType" - case Slippage: - return "Slippage" - case Page: - return "Page" - case PagePointer: - return "PagePointer" - case Limit: - return "Limit" - case LimitPointer: - return "LimitPointer" - case StatusesInts: - return "StatusesInts" - case StatusesIntsPointer: - return "StatusesIntsPointer" - case StatusesStrings: - return "StatusesStrings" - case StatusesStringsPointer: - return "StatusesStringsPointer" - case SortBy: - return "SortBy" - case SortByPointer: - return "SortByPointer" - case OrderHash: - return "OrderHash" - default: - return "Unknown" - } -} - func Parameter(parameter interface{}, variableName string, parameterType ParameterType, validationErrors []error) []error { var err error switch parameterType { + // Generic parameter types case EthereumAddress: if value, ok := parameter.(string); ok { err = CheckEthereumAddress(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(EthereumAddress)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "EthereumAddress") } case EthereumAddressPointer: if value, ok := parameter.(*string); ok { err = CheckEthereumAddressPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(EthereumAddressPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "EthereumAddressPointer") } case BigInt: if value, ok := parameter.(string); ok { err = CheckBigInt(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(BigInt)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "BigInt") } case BigIntPointer: if value, ok := parameter.(*string); ok { err = CheckBigIntPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(BigIntPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "BigIntPointer") } case ChainID: if value, ok := parameter.(int); ok { err = CheckChainId(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, GetParameterTypeName(ChainID)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, "ChainID") } case PrivateKey: if value, ok := parameter.(string); ok { err = CheckPrivateKey(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(PrivateKey)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "PrivateKey") } + + // Swap parameter types case ApprovalType: if value, ok := parameter.(int); ok { err = CheckApprovalType(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, GetParameterTypeName(ApprovalType)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, "ApprovalType") } case Slippage: if value, ok := parameter.(float32); ok { err = CheckSlippage(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Slippage)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Slippage") } + + // Orderbook parameter types case Page: if value, ok := parameter.(float32); ok { err = CheckPage(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Page)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Page") } case PagePointer: if value, ok := parameter.(*float32); ok { err = CheckPagePointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, GetParameterTypeName(PagePointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, "PagePointer") } case Limit: if value, ok := parameter.(float32); ok { err = CheckLimit(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Limit)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Limit") } case LimitPointer: if value, ok := parameter.(*float32); ok { err = CheckLimitPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, GetParameterTypeName(LimitPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, "LimitPointer") } case StatusesInts: if value, ok := parameter.([]float32); ok { err = CheckStatusesInts(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []float32", variableName, GetParameterTypeName(StatusesInts)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []float32", variableName, "StatusesInts") } case StatusesIntsPointer: if value, ok := parameter.(*[]float32); ok { err = CheckStatusesIntsPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]float32 pointer", variableName, GetParameterTypeName(StatusesIntsPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]float32 pointer", variableName, "StatusesIntsPointer") } case StatusesStrings: if value, ok := parameter.([]string); ok { err = CheckStatusesStrings(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []string", variableName, GetParameterTypeName(StatusesStrings)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []string", variableName, "StatusesStrings") } case StatusesStringsPointer: if value, ok := parameter.(*[]string); ok { err = CheckStatusesStringsPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]string pointer", variableName, GetParameterTypeName(StatusesStringsPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]string pointer", variableName, "StatusesStringsPointer") } case SortBy: if value, ok := parameter.(string); ok { err = CheckSortBy(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(SortBy)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "SortBy") } case SortByPointer: if value, ok := parameter.(*string); ok { err = CheckSortByPointer(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(SortByPointer)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "SortByPointer") } case OrderHash: if value, ok := parameter.(string); ok { err = CheckOrderHash(value, variableName) } else { - err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(OrderHash)) + err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "OrderHash") } default: err = errors.New("unknown parameter type") diff --git a/golang/client/validate/parameter_types_test.go b/golang/client/validate/parameter_types_test.go new file mode 100644 index 00000000..e3dd7865 --- /dev/null +++ b/golang/client/validate/parameter_types_test.go @@ -0,0 +1,62 @@ +package validate + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParameterFunctionConsistency(t *testing.T) { + err := checkParameterFunctionConsistency("parameter_types.go") + assert.NoError(t, err, "Parameter function consistency check failed") +} + +func checkParameterFunctionConsistency(filePath string) error { + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + caseRegex := regexp.MustCompile(`^[\t\s]*case\s+(\w+):`) + scanner := bufio.NewScanner(file) + + lineNumber := 0 + for scanner.Scan() { + lineNumber++ + line := scanner.Text() + + matches := caseRegex.FindStringSubmatch(line) + if len(matches) == 2 { + caseLabel := matches[1] + + // Move the scanner four lines down + for i := 0; i < 4; i++ { + if !scanner.Scan() { + return fmt.Errorf("expected error message four lines down from case '%s' at line %d, but reached EOF", caseLabel, lineNumber) + } + lineNumber++ + } + + errorLine := scanner.Text() + if !strings.HasSuffix(errorLine, fmt.Sprintf(`"%s")`, caseLabel)) { + + finalStringLiteralRegex := regexp.MustCompile(`"([^"]+)"\s*\)\s*$`) + stringLiteral := finalStringLiteralRegex.FindStringSubmatch(errorLine) + + return fmt.Errorf("mismatch found at line %d: case '%s' should be used in the error message. Have '%s', want '%s'", lineNumber, caseLabel, stringLiteral[1], caseLabel) + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error scanning file: %v", err) + } + + return nil +} diff --git a/golang/project-validation/directories.go b/golang/project-validation/directories.go new file mode 100644 index 00000000..c1e0414d --- /dev/null +++ b/golang/project-validation/directories.go @@ -0,0 +1,27 @@ +package project_validation + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +const projectValidationError = `this test must be run with the working directory set to the "golang" folder of the SDK project` + +func validateWorkingDirectory() error { + currentDir, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current directory: %v", err) + } + + dirName := filepath.Base(currentDir) + if dirName != "project-validation" { + return fmt.Errorf("%v. Current directory: %v", projectValidationError, currentDir) + } + return nil +} + +func vanityPath(path string) string { + return strings.Replace(path, "..", "1inch-sdk/golang", 1) +} diff --git a/golang/project-validation/integration_check_cleanup_test.go b/golang/project-validation/integration_check_cleanup_test.go new file mode 100644 index 00000000..c3ff6787 --- /dev/null +++ b/golang/project-validation/integration_check_cleanup_test.go @@ -0,0 +1,65 @@ +package project_validation + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +const targetBlock = "helpers.Sleep()" + +func TestIntegrationSleepCleanup(t *testing.T) { + + err := validateWorkingDirectory() + if err != nil { + t.Fatalf("Directory error: %v", err) + } + + err = filepath.Walk("./..", func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if strings.HasSuffix(path, "_integration_test.go") || strings.HasSuffix(path, "_e2e_test.go") { + content, err := os.ReadFile(path) + if err != nil { + return err + } + + inTRunBlock := false + openBracesCount := 0 + tRunContent := "" + + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "t.Run(") { + inTRunBlock = true + } + + if inTRunBlock { + tRunContent += line + "\n" + openBracesCount += strings.Count(line, "{") - strings.Count(line, "}") + + if openBracesCount == 0 { + if !strings.Contains(tRunContent, targetBlock) { + t.Errorf(`File: %s - Missing cleanup block in one of the tests. Each test must have the following code in it: + t.Cleanup(func() { + helpers.Sleep() + }) +`, vanityPath(path)) + } + inTRunBlock = false + tRunContent = "" + } + } + } + } + return nil + }) + if err != nil { + t.Fatalf("Error walking through files: %v", err) + } +} diff --git a/golang/project-validation/integration_test_validation/main.go b/golang/project-validation/integration_test_validation/main.go deleted file mode 100644 index d061e71f..00000000 --- a/golang/project-validation/integration_test_validation/main.go +++ /dev/null @@ -1,84 +0,0 @@ -package main - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" -) - -// This is a helper script to ensure all table-driven integration tests call the helpers.sleep() function -// Can only be run when the working directory is set to the "golang" folder of the SDK - -const targetBlock = "helpers.Sleep()" - -func main() { - currentDir, err := os.Getwd() - if err != nil { - fmt.Println("Failed to get current directory:", err) - return - } - - dirName := filepath.Base(currentDir) - if dirName != "golang" { - fmt.Println(`This script must be run specifically from the "golang" folder of the SDK project.`) - return - } - - err = filepath.Walk(".", func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if strings.HasSuffix(path, "_integration_test.go") || strings.HasSuffix(path, "_e2e_test.go") { - content, err := os.ReadFile(path) - if err != nil { - return err - } - - functionName := extractFunctionName(string(content)) - - inTRunBlock := false - openBracesCount := 0 - tRunContent := "" - - lines := strings.Split(string(content), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - - if strings.HasPrefix(line, "t.Run(") { - inTRunBlock = true - } - - if inTRunBlock { - tRunContent += line + "\n" - openBracesCount += strings.Count(line, "{") - strings.Count(line, "}") - - if openBracesCount == 0 { - if !strings.Contains(tRunContent, targetBlock) { - fmt.Printf("File: %s - Missing cleanup block in the test: %s\n", path, functionName) - } - inTRunBlock = false - tRunContent = "" - } - } - } - } - return nil - }) - if err != nil { - fmt.Println("Failed to walk the path:", err) - return - } - fmt.Println("Done!") -} - -func extractFunctionName(content string) string { - r := regexp.MustCompile(`func (\w+)\(t \*testing.T\)`) - matches := r.FindStringSubmatch(content) - if len(matches) > 1 { - return matches[1] - } - return "Unknown" -} diff --git a/golang/project-validation/check_parameter_validation_function_calls/main.go b/golang/project-validation/parameter_validation_check_format_test.go similarity index 55% rename from golang/project-validation/check_parameter_validation_function_calls/main.go rename to golang/project-validation/parameter_validation_check_format_test.go index 796531e6..da0e2424 100644 --- a/golang/project-validation/check_parameter_validation_function_calls/main.go +++ b/golang/project-validation/parameter_validation_check_format_test.go @@ -1,4 +1,4 @@ -package main +package project_validation import ( "fmt" @@ -8,48 +8,42 @@ import ( "os" "path/filepath" "strings" + "testing" ) -// This is a helper script to ensure all parameter validation functions in the SDK properly pair each parameter name with the correct string -// Can only be run when the working directory is set to the "golang" folder of the SDK +// This is a project validation test to ensure all parameter validation functions in the SDK properly pair each parameter name with the correct string -func main() { - currentDir, err := os.Getwd() - if err != nil { - fmt.Println("Failed to get current directory:", err) - return - } +func TestParameterConsistency(t *testing.T) { - dirName := filepath.Base(currentDir) - if dirName != "golang" { - fmt.Println(`This script must be run specifically from the "golang" folder of the SDK project.`) - return + err := validateWorkingDirectory() + if err != nil { + t.Fatalf("Directory error: %v", err) } - err = filepath.Walk(currentDir, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk("./..", func(path string, info os.FileInfo, err error) error { if err != nil { return err } if strings.HasSuffix(path, "types.go") { - fmt.Println("Checking file:", trimPath(path)) - checkFile(path) + if err := checkFile(t, path); err != nil { + t.Errorf("Failed to check file %v: %v", trimPath(path), err) + } } return nil }) - if err != nil { - fmt.Println("Error walking through files:", err) + t.Fatalf("Error walking through files: %v", err) } } -func checkFile(filename string) { +func checkFile(t *testing.T, filePath string) error { fset := token.NewFileSet() - node, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments) if err != nil { - fmt.Println("Error parsing file:", err) - return + return fmt.Errorf("error parsing file: %w", err) } + var lastErr error ast.Inspect(node, func(n ast.Node) bool { callExpr, ok := n.(*ast.CallExpr) if !ok { @@ -60,16 +54,20 @@ func checkFile(filename string) { if ident, ok := selExpr.X.(*ast.Ident); ok && ident.Name == "validate" && selExpr.Sel.Name == "Parameter" { if len(callExpr.Args) >= 4 { firstArg, secondArg := callExpr.Args[0], callExpr.Args[1] - checkValidationCall(firstArg, secondArg, fset) + if err := checkValidationCall(firstArg, secondArg, fset); err != nil { + t.Errorf("%v", err) + } } } } return true }) + + return lastErr } -func checkValidationCall(firstArg, secondArg ast.Expr, fset *token.FileSet) { +func checkValidationCall(firstArg, secondArg ast.Expr, fset *token.FileSet) error { // Extract variable name from the first argument varName := extractVarName(firstArg) @@ -88,10 +86,11 @@ func checkValidationCall(firstArg, secondArg ast.Expr, fset *token.FileSet) { formattedVarName := strings.ToLower(string(actualVarNamePart[0])) + actualVarNamePart[1:] if formattedVarName != expectedVarName { - reportMismatch(varName, expectedVarName, stringLit, fset, formattedVarName) + return reportMismatch(varName, expectedVarName, stringLit, fset, formattedVarName) } } } + return nil } func extractVarName(expr ast.Expr) string { @@ -102,15 +101,25 @@ func extractVarName(expr ast.Expr) string { if ident, ok := v.X.(*ast.Ident); ok { return ident.Name + "." + v.Sel.Name } + case *ast.CallExpr: + // This block handles type conversion expressions. + // Check if the Fun part of the CallExpr is an identifier, which would indicate a type conversion. + if ident, ok := v.Fun.(*ast.Ident); ok { + // The first argument of the CallExpr should be the variable being converted. + if len(v.Args) > 0 { + return extractVarName(v.Args[0]) // Recursively extract the name from the first argument. + } + return ident.Name // If no arguments are found, return the type being converted to. + } } return "" } -func reportMismatch(varName, varNameAsString string, stringLit *ast.BasicLit, fset *token.FileSet, formattedVarName string) { +func reportMismatch(varName, varNameAsString string, stringLit *ast.BasicLit, fset *token.FileSet, formattedVarName string) error { position := fset.Position(stringLit.Pos()) displayPath := trimPath(position.Filename) - fmt.Printf("Mismatch found in %s at line %d: parameter '%s' should have the string literal '%s' next to it, not '%s'.\n", - displayPath, position.Line, varName, formattedVarName, varNameAsString) + return fmt.Errorf("Mismatch found in %s at line %d: parameter '%s' should have the string literal '%s' next to it, not '%s'.\n", + vanityPath(displayPath), position.Line, varName, formattedVarName, varNameAsString) } func trimPath(path string) string {