Skip to content

Commit

Permalink
Improving parameter validation logic and improving project code check…
Browse files Browse the repository at this point in the history
… tests (#11)
  • Loading branch information
Tanz0rz authored Feb 1, 2024
1 parent c22d8cd commit 60bd0dc
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,25 @@ import (

type ParameterType int

// Generic parameter types
const (
EthereumAddress ParameterType = iota
EthereumAddressPointer
BigInt
BigIntPointer
ChainID
PrivateKey
ApprovalType
)

// Swap parameter types
const (
ApprovalType ParameterType = iota + 50
Slippage
Page
)

// Orderbook parameter types
const (
Page ParameterType = iota + 100
PagePointer
Limit
LimitPointer
Expand All @@ -29,168 +38,128 @@ const (
OrderHash
)

func GetParameterTypeName(parameterType ParameterType) string {
switch parameterType {
case EthereumAddress:
return "EthereumAddress"
case EthereumAddressPointer:
return "EthereumAddressPointer"
case BigInt:
return "BigInt"
case BigIntPointer:
return "BigIntPointer"
case ChainID:
return "ChainID"
case PrivateKey:
return "PrivateKey"
case ApprovalType:
return "ApprovalType"
case Slippage:
return "Slippage"
case Page:
return "Page"
case PagePointer:
return "PagePointer"
case Limit:
return "Limit"
case LimitPointer:
return "LimitPointer"
case StatusesInts:
return "StatusesInts"
case StatusesIntsPointer:
return "StatusesIntsPointer"
case StatusesStrings:
return "StatusesStrings"
case StatusesStringsPointer:
return "StatusesStringsPointer"
case SortBy:
return "SortBy"
case SortByPointer:
return "SortByPointer"
case OrderHash:
return "OrderHash"
default:
return "Unknown"
}
}

func Parameter(parameter interface{}, variableName string, parameterType ParameterType, validationErrors []error) []error {
var err error

switch parameterType {
// Generic parameter types
case EthereumAddress:
if value, ok := parameter.(string); ok {
err = CheckEthereumAddress(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(EthereumAddress))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "EthereumAddress")
}
case EthereumAddressPointer:
if value, ok := parameter.(*string); ok {
err = CheckEthereumAddressPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(EthereumAddressPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "EthereumAddressPointer")
}
case BigInt:
if value, ok := parameter.(string); ok {
err = CheckBigInt(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(BigInt))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "BigInt")
}
case BigIntPointer:
if value, ok := parameter.(*string); ok {
err = CheckBigIntPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(BigIntPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "BigIntPointer")
}
case ChainID:
if value, ok := parameter.(int); ok {
err = CheckChainId(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, GetParameterTypeName(ChainID))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, "ChainID")
}
case PrivateKey:
if value, ok := parameter.(string); ok {
err = CheckPrivateKey(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(PrivateKey))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "PrivateKey")
}

// Swap parameter types
case ApprovalType:
if value, ok := parameter.(int); ok {
err = CheckApprovalType(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, GetParameterTypeName(ApprovalType))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be an int", variableName, "ApprovalType")
}
case Slippage:
if value, ok := parameter.(float32); ok {
err = CheckSlippage(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Slippage))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Slippage")
}

// Orderbook parameter types
case Page:
if value, ok := parameter.(float32); ok {
err = CheckPage(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Page))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Page")
}
case PagePointer:
if value, ok := parameter.(*float32); ok {
err = CheckPagePointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, GetParameterTypeName(PagePointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, "PagePointer")
}
case Limit:
if value, ok := parameter.(float32); ok {
err = CheckLimit(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, GetParameterTypeName(Limit))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32", variableName, "Limit")
}
case LimitPointer:
if value, ok := parameter.(*float32); ok {
err = CheckLimitPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, GetParameterTypeName(LimitPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a float32 pointer", variableName, "LimitPointer")
}
case StatusesInts:
if value, ok := parameter.([]float32); ok {
err = CheckStatusesInts(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []float32", variableName, GetParameterTypeName(StatusesInts))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []float32", variableName, "StatusesInts")
}
case StatusesIntsPointer:
if value, ok := parameter.(*[]float32); ok {
err = CheckStatusesIntsPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]float32 pointer", variableName, GetParameterTypeName(StatusesIntsPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]float32 pointer", variableName, "StatusesIntsPointer")
}
case StatusesStrings:
if value, ok := parameter.([]string); ok {
err = CheckStatusesStrings(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []string", variableName, GetParameterTypeName(StatusesStrings))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a []string", variableName, "StatusesStrings")
}
case StatusesStringsPointer:
if value, ok := parameter.(*[]string); ok {
err = CheckStatusesStringsPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]string pointer", variableName, GetParameterTypeName(StatusesStringsPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a *[]string pointer", variableName, "StatusesStringsPointer")
}
case SortBy:
if value, ok := parameter.(string); ok {
err = CheckSortBy(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(SortBy))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "SortBy")
}
case SortByPointer:
if value, ok := parameter.(*string); ok {
err = CheckSortByPointer(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, GetParameterTypeName(SortByPointer))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string pointer", variableName, "SortByPointer")
}
case OrderHash:
if value, ok := parameter.(string); ok {
err = CheckOrderHash(value, variableName)
} else {
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, GetParameterTypeName(OrderHash))
err = fmt.Errorf("for parameter '%v' to be validated as '%v', it must be a string", variableName, "OrderHash")
}
default:
err = errors.New("unknown parameter type")
Expand Down
62 changes: 62 additions & 0 deletions golang/client/validate/parameter_types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package validate

import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParameterFunctionConsistency(t *testing.T) {
err := checkParameterFunctionConsistency("parameter_types.go")
assert.NoError(t, err, "Parameter function consistency check failed")
}

func checkParameterFunctionConsistency(filePath string) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()

caseRegex := regexp.MustCompile(`^[\t\s]*case\s+(\w+):`)
scanner := bufio.NewScanner(file)

lineNumber := 0
for scanner.Scan() {
lineNumber++
line := scanner.Text()

matches := caseRegex.FindStringSubmatch(line)
if len(matches) == 2 {
caseLabel := matches[1]

// Move the scanner four lines down
for i := 0; i < 4; i++ {
if !scanner.Scan() {
return fmt.Errorf("expected error message four lines down from case '%s' at line %d, but reached EOF", caseLabel, lineNumber)
}
lineNumber++
}

errorLine := scanner.Text()
if !strings.HasSuffix(errorLine, fmt.Sprintf(`"%s")`, caseLabel)) {

finalStringLiteralRegex := regexp.MustCompile(`"([^"]+)"\s*\)\s*$`)
stringLiteral := finalStringLiteralRegex.FindStringSubmatch(errorLine)

return fmt.Errorf("mismatch found at line %d: case '%s' should be used in the error message. Have '%s', want '%s'", lineNumber, caseLabel, stringLiteral[1], caseLabel)
}
}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("error scanning file: %v", err)
}

return nil
}
27 changes: 27 additions & 0 deletions golang/project-validation/directories.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package project_validation

import (
"fmt"
"os"
"path/filepath"
"strings"
)

const projectValidationError = `this test must be run with the working directory set to the "golang" folder of the SDK project`

func validateWorkingDirectory() error {
currentDir, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get current directory: %v", err)
}

dirName := filepath.Base(currentDir)
if dirName != "project-validation" {
return fmt.Errorf("%v. Current directory: %v", projectValidationError, currentDir)
}
return nil
}

func vanityPath(path string) string {
return strings.Replace(path, "..", "1inch-sdk/golang", 1)
}
65 changes: 65 additions & 0 deletions golang/project-validation/integration_check_cleanup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package project_validation

import (
"os"
"path/filepath"
"strings"
"testing"
)

const targetBlock = "helpers.Sleep()"

func TestIntegrationSleepCleanup(t *testing.T) {

err := validateWorkingDirectory()
if err != nil {
t.Fatalf("Directory error: %v", err)
}

err = filepath.Walk("./..", func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if strings.HasSuffix(path, "_integration_test.go") || strings.HasSuffix(path, "_e2e_test.go") {
content, err := os.ReadFile(path)
if err != nil {
return err
}

inTRunBlock := false
openBracesCount := 0
tRunContent := ""

lines := strings.Split(string(content), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)

if strings.HasPrefix(line, "t.Run(") {
inTRunBlock = true
}

if inTRunBlock {
tRunContent += line + "\n"
openBracesCount += strings.Count(line, "{") - strings.Count(line, "}")

if openBracesCount == 0 {
if !strings.Contains(tRunContent, targetBlock) {
t.Errorf(`File: %s - Missing cleanup block in one of the tests. Each test must have the following code in it:
t.Cleanup(func() {
helpers.Sleep()
})
`, vanityPath(path))
}
inTRunBlock = false
tRunContent = ""
}
}
}
}
return nil
})
if err != nil {
t.Fatalf("Error walking through files: %v", err)
}
}
Loading

0 comments on commit 60bd0dc

Please sign in to comment.