Skip to content

Commit

Permalink
🪄 clean up the bpe algorithm
Browse files Browse the repository at this point in the history
Renamed the struct and its variables to better reflect what they do.
  • Loading branch information
bluescreen10 committed Apr 9, 2023
1 parent 12a3097 commit 4868b0b
Showing 1 changed file with 27 additions and 34 deletions.
61 changes: 27 additions & 34 deletions codec/codec.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package codec

import (
"errors"
"fmt"
"math"
"regexp"
Expand Down Expand Up @@ -58,68 +57,62 @@ func (c *Codec) Decode(tokens []uint) (string, error) {
}

func (c *Codec) bpe(piece []byte) ([]uint, []string) {
type byteRange struct {
start int
end int
type part struct {
offset int
rank uint
}

parts := make([]byteRange, len(piece)+1)
parts := make([]part, len(piece)+1)
for i := 0; i < len(parts); i++ {
parts[i] = byteRange{i, math.MaxInt64}
parts[i] = part{i, math.MaxUint}
}

getRank := func(parts []byteRange, startIdx, skip int) (uint, error) {
if startIdx+skip+2 < len(parts) {
chunk := string(piece[parts[startIdx].start:parts[startIdx+skip+2].start])
if rank, ok := c.vocabulary[chunk]; ok {
return rank, nil
getRank := func(index, skip int) uint {
if index+skip+2 < len(parts) {
start := parts[index].offset
end := parts[index+skip+2].offset
if rank, ok := c.vocabulary[string(piece[start:end])]; ok {
return rank
}
}
return math.MaxInt64, errors.New("not found")
return math.MaxUint
}

for i := 0; i < len(parts)-2; i++ {
if r, err := getRank(parts, i, 0); err == nil {
parts[i].end = int(r)
}
parts[i].rank = getRank(i, 0)
}

for {
if len(parts) == 1 {
break
}

minRank := byteRange{math.MaxInt64, 0}
minRank := uint(math.MaxUint)
minIndex := 0
for i, p := range parts[:len(parts)-1] {
if p.end < minRank.start {
minRank = byteRange{p.end, i}
if p.rank < minRank {
minRank = p.rank
minIndex = i
}
}

if minRank.start != math.MaxInt64 {
i := minRank.end
if minRank == math.MaxUint {
break
}

parts[i].end = math.MaxInt64
if r, err := getRank(parts, i, 1); err == nil {
parts[i].end = int(r)
}
if i > 0 {
parts[i-1].end = math.MaxInt64
if r, err := getRank(parts, i-1, 1); err == nil {
parts[i-1].end = int(r)
}
}
parts[minIndex].rank = getRank(minIndex, 1)

parts = append(parts[:i+1], parts[i+2:]...)
} else {
break
if minIndex > 0 {
parts[minIndex-1].rank = getRank(minIndex-1, 1)
}

parts = append(parts[:minIndex+1], parts[minIndex+2:]...)
}

ids := make([]uint, len(parts)-1)
tokens := make([]string, len(parts)-1)
for i := 0; i < len(ids); i++ {
token := string(piece[parts[i].start:parts[i+1].start])
token := string(piece[parts[i].offset:parts[i+1].offset])
tokens[i] = token
ids[i] = c.vocabulary[token]
}
Expand Down

0 comments on commit 4868b0b

Please sign in to comment.