Skip to content

Commit

Permalink
better error handling API
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaeguard committed Sep 21, 2023
1 parent e2716be commit dda2a2e
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 79 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ a very simple regex engine written in go.
- [x] extracting the string that matches with the regex
- [x] `\` escape character
- [x] support special characters - context dependent
- [ ] better error handling in the API
- [x] better error handling in the API
- [ ] ability to work on multi-line strings
- [ ] `.` should not match the newline - `\n`
- [ ] `$` should match the newline - `\n`
Expand Down
11 changes: 2 additions & 9 deletions check.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package rgx

import "fmt"

func getChar(input string, pos int) uint8 {
if pos >= 0 && pos < len(input) {
return input[pos]
Expand All @@ -17,15 +15,10 @@ func getChar(input string, pos int) uint8 {
// get the next state given the 'ch' as an input
func (s *State) nextStateWith(ch uint8) *State {
states := s.transitions[ch]

size := len(states)

if size == 0 {
if len(states) == 0 {
return nil
} else if size == 1 {
return states[0]
}
panic(fmt.Sprintf("There must be at most 1 transition, found %d", size))
return states[0]
}

// checks if the inputString is accepted by this NFA
Expand Down
21 changes: 21 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package rgx

import "fmt"

// ParseErrorCode classifies the kind of failure produced while parsing
// or compiling a regex pattern.
type ParseErrorCode string

const (
	// SyntaxError indicates the pattern text itself is malformed.
	SyntaxError ParseErrorCode = "SyntaxError"
	// Unimplemented indicates the pattern uses a feature this engine
	// does not support yet.
	Unimplemented ParseErrorCode = "Unimplemented"
	// CompilationError indicates the pattern parsed but could not be
	// turned into an NFA (e.g. a backreference to an unknown group).
	CompilationError ParseErrorCode = "CompilationError"
)

// RegexError is the error type returned by the public API.
type RegexError struct {
	Code    ParseErrorCode // category of the failure
	Message string         // human-readable details
	Pos     int            // position in the pattern, when known
}

// Error implements the error interface.
func (p *RegexError) Error() string {
	return fmt.Sprintf("code=%s, message=%s, pos=%d", p.Code, p.Message, p.Pos)
}
14 changes: 10 additions & 4 deletions lib.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package rgx

// Compile compiles the given regex string
func Compile(regexString string) *State {
func Compile(regexString string) (*State, *RegexError) {
parseContext := parsingContext{
pos: 0,
tokens: []regexToken{},
capturedGroups: map[string]bool{},
}
regex(regexString, &parseContext)
if err := parse(regexString, &parseContext); err != nil {
return nil, err
}
return toNfa(&parseContext)
}

Expand Down Expand Up @@ -70,6 +72,10 @@ func (s *State) FindMatches(inputString string) []Result {
}

// Check compiles the regexString and tests the inputString against it
func Check(regexString string, inputString string) Result {
return Compile(regexString).Test(inputString)
func Check(regexString string, inputString string) (Result, *RegexError) {
compiledNfa, err := Compile(regexString)
if err != nil {
return Result{}, err
}
return compiledNfa.Test(inputString), nil
}
20 changes: 15 additions & 5 deletions lib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ func TestCheck(t *testing.T) {
for _, test := range data {
testName := fmt.Sprintf("%s-%s-%t", test.regexString, test.input, test.expected)
t.Run(testName, func(t *testing.T) {
if test.expected != Check(test.regexString, test.input).matches {
_ = fmt.Errorf("test %s failed", testName)
t.Fail()
result, err := Check(test.regexString, test.input)
if err != nil {
t.Errorf(err.Error())
}
if test.expected != result.matches {
t.Errorf("test %s failed", testName)
}
})
}
Expand All @@ -152,7 +155,11 @@ func TestFindMatches(t *testing.T) {
for _, test := range data {
testName := fmt.Sprintf("%s-%s-%v", test.regexString, test.input, test.expected)
t.Run(testName, func(t *testing.T) {
results := Compile(test.regexString).FindMatches(test.input)
pattern, err := Compile(test.regexString)
if err != nil {
t.Errorf(err.Error())
}
results := pattern.FindMatches(test.input)
if len(results) != len(test.expected) {
t.Fail()
}
Expand All @@ -179,7 +186,10 @@ func TestCheckForDev(t *testing.T) {
testName := fmt.Sprintf("%s-%s-%t", test.regexString, test.input, test.expected)
t.Run(testName, func(t *testing.T) {
dumpDotGraphForRegex(test.regexString)
result := Check(test.regexString, test.input)
result, err := Check(test.regexString, test.input)
if err != nil {
t.Errorf(err.Error())
}
if test.expected != result.matches {
_ = fmt.Errorf("test %s failed", testName)
t.Fail()
Expand Down
103 changes: 74 additions & 29 deletions nfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,24 @@ const (
AnyChar = 3
)

func toNfa(memory *parsingContext) *State {
func toNfa(memory *parsingContext) (*State, *RegexError) {
startFrom := 0
endAt := len(memory.tokens) - 1

token := memory.tokens[startFrom]
startState, endState := tokenToNfa(token, memory, &State{
startState, endState, err := tokenToNfa(token, memory, &State{
transitions: map[uint8][]*State{},
})

if err != nil {
return nil, err
}

for i := startFrom + 1; i <= endAt; i++ {
_, endNext := tokenToNfa(memory.tokens[i], memory, endState)
_, endNext, err := tokenToNfa(memory.tokens[i], memory, endState)
if err != nil {
return nil, err
}
endState = endNext
}

Expand Down Expand Up @@ -72,32 +79,38 @@ func toNfa(memory *parsingContext) *State {

endState.transitions[EpsilonChar] = append(endState.transitions[EpsilonChar], end)

return start
return start, nil
}

func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*State, *State) {
func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*State, *State, *RegexError) {
switch token.tokenType {
case Literal:
value := token.value.(uint8)
to := &State{
transitions: map[uint8][]*State{},
}
startFrom.transitions[value] = append(startFrom.transitions[value], to)
return startFrom, to
startFrom.transitions[value] = []*State{to}
return startFrom, to, nil
case Quantifier:
return handleQuantifierToToken(token, memory, startFrom)
case Wildcard:
to := &State{
transitions: map[uint8][]*State{},
}

startFrom.transitions[AnyChar] = append(startFrom.transitions[AnyChar], to)
startFrom.transitions[AnyChar] = []*State{to}

return startFrom, to
return startFrom, to, nil
case Or:
values := token.value.([]regexToken)
_, end1 := tokenToNfa(values[0], memory, startFrom)
_, end2 := tokenToNfa(values[1], memory, startFrom)
_, end1, err := tokenToNfa(values[0], memory, startFrom)
if err != nil {
return nil, nil, err
}
_, end2, err := tokenToNfa(values[1], memory, startFrom)
if err != nil {
return nil, nil, err
}

to := &State{
transitions: map[uint8][]*State{},
Expand All @@ -106,17 +119,24 @@ func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*St
end1.transitions[EpsilonChar] = append(end1.transitions[EpsilonChar], to)
end2.transitions[EpsilonChar] = append(end2.transitions[EpsilonChar], to)

return startFrom, to
return startFrom, to, nil
case Group:
v := token.value.(groupTokenPayload)

// concatenate all the elements in the group
start, end := tokenToNfa(v.tokens[0], memory, &State{
start, end, err := tokenToNfa(v.tokens[0], memory, &State{
transitions: map[uint8][]*State{},
})

if err != nil {
return nil, nil, err
}

for i := 1; i < len(v.tokens); i++ {
_, endNext := tokenToNfa(v.tokens[i], memory, end)
_, endNext, err := tokenToNfa(v.tokens[i], memory, end)
if err != nil {
return nil, nil, err
}
end = endNext
}
// concatenation ends
Expand Down Expand Up @@ -156,21 +176,30 @@ func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*St
}

startFrom.transitions[EpsilonChar] = append(startFrom.transitions[EpsilonChar], start)
return startFrom, end
return startFrom, end, nil
case GroupUncaptured:
values := token.value.([]regexToken)

start, end := tokenToNfa(values[0], memory, &State{
start, end, err := tokenToNfa(values[0], memory, &State{
transitions: map[uint8][]*State{},
})

if err != nil {
return nil, nil, err
}

for i := 1; i < len(values); i++ {
_, endNext := tokenToNfa(values[i], memory, end)
_, endNext, err := tokenToNfa(values[i], memory, end)

if err != nil {
return nil, nil, err
}

end = endNext
}

startFrom.transitions[EpsilonChar] = append(startFrom.transitions[EpsilonChar], start)
return startFrom, end
return startFrom, end, nil
case Bracket:
constructTokens := token.value.([]regexToken)

Expand All @@ -183,7 +212,7 @@ func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*St
startFrom.transitions[ch] = []*State{to}
}

return startFrom, to
return startFrom, to, nil
case BracketNot:
constructTokens := token.value.([]regexToken)

Expand All @@ -201,21 +230,24 @@ func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*St
}
startFrom.transitions[AnyChar] = []*State{to}

return startFrom, to
return startFrom, to, nil
case TextBeginning:
to := &State{
transitions: map[uint8][]*State{},
}
startFrom.startOfText = true
startFrom.transitions[EpsilonChar] = append(startFrom.transitions[EpsilonChar], to)
return startFrom, to
return startFrom, to, nil
case TextEnd:
startFrom.endOfText = true
return startFrom, startFrom
return startFrom, startFrom, nil
case Backreference:
groupName := token.value.(string)
if _, ok := memory.capturedGroups[groupName]; !ok {
panic(fmt.Sprintf("Group (%s) does not exist", groupName))
return nil, nil, &RegexError{
Code: CompilationError,
Message: fmt.Sprintf("Group (%s) does not exist", groupName),
}
}
to := &State{
transitions: map[uint8][]*State{},
Expand All @@ -226,13 +258,16 @@ func tokenToNfa(token regexToken, memory *parsingContext, startFrom *State) (*St
target: to,
}

return startFrom, to
return startFrom, to, nil
default:
panic(fmt.Sprintf("unrecognized token: %+v", token))
return nil, nil, &RegexError{
Code: CompilationError,
Message: fmt.Sprintf("unrecognized token: %+v", token),
}
}
}

func handleQuantifierToToken(token regexToken, memory *parsingContext, startFrom *State) (*State, *State) {
func handleQuantifierToToken(token regexToken, memory *parsingContext, startFrom *State) (*State, *State, *RegexError) {
payload := token.value.(quantifier)
// the minimum amount of time the NFA needs to repeat
min := payload.min
Expand Down Expand Up @@ -265,17 +300,27 @@ func handleQuantifierToToken(token regexToken, memory *parsingContext, startFrom
} else {
value = token.value.([]regexToken)[0]
}
previousStart, previousEnd := tokenToNfa(value, memory, &State{
previousStart, previousEnd, err := tokenToNfa(value, memory, &State{
transitions: map[uint8][]*State{},
})

if err != nil {
return nil, nil, err
}

startFrom.transitions[EpsilonChar] = append(startFrom.transitions[EpsilonChar], previousStart)

// starting from 2, because the one above is the first one
for i := 2; i <= total; i++ {
// the same NFA needs to be generated 'total' times
start, end := tokenToNfa(value, memory, &State{
start, end, err := tokenToNfa(value, memory, &State{
transitions: map[uint8][]*State{},
})

if err != nil {
return nil, nil, err
}

// connect the end of the previous one to the start of this one
previousEnd.transitions[EpsilonChar] = append(previousEnd.transitions[EpsilonChar], start)

Expand All @@ -295,5 +340,5 @@ func handleQuantifierToToken(token regexToken, memory *parsingContext, startFrom
if max == QuantifierInfinity {
to.transitions[EpsilonChar] = append(to.transitions[EpsilonChar], previousStart)
}
return startFrom, to
return startFrom, to, nil
}
Loading

0 comments on commit dda2a2e

Please sign in to comment.