Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 90 additions & 37 deletions added-vocabulary.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sort"
"unicode"

"github.com/sugarme/regexpset"
"github.com/sugarme/tokenizer/normalizer"
)

Expand Down Expand Up @@ -180,8 +179,8 @@ func isWordCharacter(r rune) bool {

// matchingSet is a set of regular expression string
type matchingSet struct {
regexSet regexpset.RegexpSet
ids []int
ids []int
regexps []*regexp.Regexp
}

// AddedVocabulary is a vocabulary built on top of the Model
Expand Down Expand Up @@ -346,17 +345,18 @@ func (av *AddedVocabulary) refreshAddedTokens(model Model, normalizer normalizer
}
}

normSet, err := regexpset.NewRegexpSet(normPatterns)
if err != nil {
log.Fatal(err)
normRegexps := make([]*regexp.Regexp, len(normPatterns))
for i, p := range normPatterns {
normRegexps[i] = regexp.MustCompile(p)
}
nnormSet, err := regexpset.NewRegexpSet(nnormPatterns)
if err != nil {
log.Fatal(err)

nnormRegexps := make([]*regexp.Regexp, len(nnormPatterns))
for i, p := range nnormPatterns {
nnormRegexps[i] = regexp.MustCompile(p)
}

av.splitNormalizedRe = matchingSet{*normSet, normIds}
av.splitRe = matchingSet{*nnormSet, nnormIds}
av.splitNormalizedRe = matchingSet{normIds, normRegexps}
av.splitRe = matchingSet{nnormIds, nnormRegexps}
}

type idOffsets struct {
Expand Down Expand Up @@ -390,11 +390,9 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re
return []idOffsets{{-1, []int{0, 0}}}
}

matches := splitRe.regexSet.Matches(sentence).Matches()
var ioPairs []idOffsets
ioPairs := make([]idOffsets, 0, len(splitRe.regexps)*2)

for _, idx := range matches {
r := regexp.MustCompile(splitRe.regexSet.Patterns()[idx])
for idx, r := range splitRe.regexps {
locs := r.FindAllStringIndex(sentence, -1)
for _, loc := range locs {
id := idx
Expand All @@ -403,15 +401,22 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re
}
}

// Sort id-offsets by start then by pattern id
sort.Sort(byStart(ioPairs))
sort.Sort(byId(ioPairs))
// Sort id-offsets by start, then by pattern id.
sort.Slice(ioPairs, func(i, j int) bool {
if ioPairs[i].offsets[0] != ioPairs[j].offsets[0] {
return ioPairs[i].offsets[0] < ioPairs[j].offsets[0]
}
if ioPairs[i].offsets[1] != ioPairs[j].offsets[1] {
return ioPairs[i].offsets[1] < ioPairs[j].offsets[1]
}
return ioPairs[i].id < ioPairs[j].id
})

// Select the matches, if they overlap, keep them
// Select matches greedily. With sort(start, id), overlapping ties pick lowest id.
var (
i int = 0
currentOffsets int = 0
splits []idOffsets = make([]idOffsets, 0)
i int = 0
currentOffsets int = 0
splits = make([]idOffsets, 0, len(ioPairs))
)

for i < len(ioPairs) {
Expand All @@ -423,20 +428,6 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re
continue
}

// Find out whether having overlapping neighbours.
// If so, keep the one with lowest Idx. All other will be skipped
// because `currentOffsets` will have been increased.
if i+1 < len(ioPairs) {
overlapPairs := ioPairs[i:]
sort.Sort(byId(overlapPairs))
lowestPair := overlapPairs[0] // lowest Id one
splits = append(splits, lowestPair)
currentOffsets = ioPair.offsets[1]
i++
continue
}

// Not found overlap neighbours. Just apply itself
splits = append(splits, ioPair)
currentOffsets = ioPair.offsets[1]
i++
Expand All @@ -445,7 +436,7 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re
// Also, insert the splits in-between added tokens, to split the entire string
var (
startOffset int = 0
finalSplits []idOffsets
finalSplits = make([]idOffsets, 0, len(splits)*2+1)
)

for _, ioPair := range splits {
Expand Down Expand Up @@ -501,6 +492,9 @@ func (av *AddedVocabulary) splitWithIndices(sentence *normalizer.NormalizedStrin
// input sentence `I read a book Yesterday`, if the normalizer is supposed to lowercase
// everything, we expect a match.
func (av *AddedVocabulary) ExtractAndNormalize(sequence string, n normalizer.Normalizer) *PreTokenizedString {
if len(av.splitRe.regexps) == 0 && len(av.splitNormalizedRe.regexps) == 0 {
return NewPreTokenizedString(sequence)
}

pretokenized := NewPreTokenizedString(sequence)

Expand All @@ -525,12 +519,71 @@ func (av *AddedVocabulary) ExtractAndNormalize(sequence string, n normalizer.Nor
return pretok2
}

// ExtractAndNormalizeFast is the offset-free counterpart of ExtractAndNormalize:
// it builds NormalizedStrings without offset tracking, trading offset mappings
// for speed. Use it when only token IDs are needed.
func (av *AddedVocabulary) ExtractAndNormalizeFast(sequence string, n normalizer.Normalizer) *PreTokenizedString {
	// Fast path: no added-token patterns registered, nothing to split on.
	noPatterns := len(av.splitRe.regexps) == 0 && len(av.splitNormalizedRe.regexps) == 0
	if noPatterns {
		return NewPreTokenizedStringFast(sequence)
	}

	pretokenized := NewPreTokenizedStringFast(sequence)

	// First pass: split on non-normalized (raw) added-token patterns.
	afterRaw := pretokenized.Split(func(idx int, seq *normalizer.NormalizedString) []SplitIdx {
		return av.splitWithIndices(seq, av.splitRe)
	})

	// Second pass: normalize each remaining piece (when a normalizer is
	// provided), then split on the normalized added-token patterns.
	afterNorm := afterRaw.Split(func(i int, seq *normalizer.NormalizedString) []SplitIdx {
		if n == nil {
			return av.splitWithIndices(seq, av.splitNormalizedRe)
		}
		normalized, err := n.Normalize(seq)
		if err != nil {
			log.Fatal(err)
		}
		return av.splitWithIndices(normalized, av.splitNormalizedRe)
	})

	return afterNorm
}

// AddedTokenWithId pairs an AddedToken with the exact vocabulary ID it must
// be registered under, so callers can preserve ID assignments (e.g. from a
// tokenizer.json added_tokens array) instead of computing new ones.
type AddedTokenWithId struct {
	Id      int        // Id assigned to this token
	Special bool       // whether this is a special token
	Token   AddedToken // the target AddedToken
}

// AddTokensWithIds registers tokens with explicit IDs from the tokenizer.json,
// preserving the exact ID assignments rather than computing new ones.
// This is critical for tokenizers with compacted vocabularies where the
// added_tokens array specifies exact ID values that must be respected.
// It returns the number of tokens processed (entries with empty content are
// skipped and not counted).
func (av *AddedVocabulary) AddTokensWithIds(tokenIds []AddedTokenWithId, model Model, normalizer normalizer.Normalizer) int {
	count := 0
	for _, entry := range tokenIds {
		content := entry.Token.Content
		if content == "" {
			// Nothing to register for an empty token.
			continue
		}

		// Register both directions of the mapping with the specified ID,
		// overwriting any previous assignment for the same content or ID.
		av.addedTokenMap[content] = entry.Id
		av.addedTokenMapR[entry.Id] = content

		if entry.Special {
			// Special tokens are deduplicated via the set.
			if _, seen := av.specialTokensSet[content]; !seen {
				av.specialTokens = append(av.specialTokens, entry.Token)
				av.specialTokensSet[content] = true
			}
		} else {
			// NOTE(review): non-special tokens are appended without a
			// duplicate check, unlike specials — repeated calls with the
			// same token will append it again. Confirm callers rely on
			// single registration.
			av.addedTokens = append(av.addedTokens, entry.Token)
		}
		count++
	}

	av.refreshAddedTokens(model, normalizer)
	return count
}

// Implement Serialize interface for AddedVocabular:
// =================================================

Expand Down
78 changes: 68 additions & 10 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,54 @@ func NewEncodingFromTokens(tokens []Token, typeId int) (retVal *Encoding) {
}

func (e *Encoding) Clone() *Encoding {
out := new(Encoding)
err := util.DeepCopy(e, out)
if err != nil {
panic(err)
out := cloneEncoding(*e)
return &out
}

// cloneRange returns an independent copy of r. A nil input yields nil,
// preserving the nil/non-nil distinction of the original.
func cloneRange(r Range) Range {
	if r == nil {
		return nil
	}
	dup := make(Range, len(r))
	copy(dup, r)
	return dup
}

// rangesEqual reports whether a and b have the same length and identical
// elements in order. Two nil/empty ranges compare equal.
func rangesEqual(a, b Range) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}

func cloneEncoding(in Encoding) Encoding {
out := Encoding{}

out.Ids = append([]int(nil), in.Ids...)
out.TypeIds = append([]int(nil), in.TypeIds...)
out.Tokens = append([]string(nil), in.Tokens...)
out.SpecialTokenMask = append([]int(nil), in.SpecialTokenMask...)
out.AttentionMask = append([]int(nil), in.AttentionMask...)
out.Words = append([]int(nil), in.Words...)

out.Offsets = make([][]int, len(in.Offsets))
for i, o := range in.Offsets {
out.Offsets[i] = append([]int(nil), o...)
}

out.Overflowing = make([]Encoding, len(in.Overflowing))
for i := range in.Overflowing {
out.Overflowing[i] = cloneEncoding(in.Overflowing[i])
}

out.SequenceRanges = make(map[int]Range, len(in.SequenceRanges))
for k, r := range in.SequenceRanges {
out.SequenceRanges[k] = cloneRange(r)
}

return out
Expand Down Expand Up @@ -459,9 +503,15 @@ func (e *Encoding) MergeWith(pair *Encoding, growingOffsets bool) (retVal *Encod
start := originalLen + r[0]
end := originalLen + r[r.Len()-1] + 1
newRange := NewRange(start, end)
var oldRange Range
util.DeepCopy(e.SequenceRanges[seqId], oldRange)
e.SequenceRanges[seqId] = util.Merge(oldRange, newRange)
oldRange, ok := e.SequenceRanges[seqId]
if !ok || len(oldRange) == 0 {
e.SequenceRanges[seqId] = newRange
continue
}
if rangesEqual(oldRange, newRange) {
continue
}
e.SequenceRanges[seqId] = util.Merge(cloneRange(oldRange), newRange)
}
}

Expand Down Expand Up @@ -519,8 +569,16 @@ func mergeEncoding(en1, en2 Encoding, growingOffsets bool) Encoding {
start := originalLen + r[0]
end := originalLen + r[r.Len()-1] + 1
newRange := NewRange(start, end)
oldRange := en1.SequenceRanges[seqId]
sequenceRanges[seqId] = append(oldRange, newRange...)
oldRange, ok := en1.SequenceRanges[seqId]
if !ok || len(oldRange) == 0 {
sequenceRanges[seqId] = newRange
continue
}
if rangesEqual(oldRange, newRange) {
sequenceRanges[seqId] = oldRange
continue
}
sequenceRanges[seqId] = util.Merge(cloneRange(oldRange), newRange)
}
} else {
sequenceRanges = en1.SequenceRanges
Expand Down Expand Up @@ -586,7 +644,7 @@ func (e *Encoding) pad(targetLength, padId, padTypeId int, padToken string, dire
for i := 0; i < len(newTypeIds); i++ {
newTypeIds[i] = padTypeId
}
newTypeIds = append(newTypeIds, e.Ids...)
newTypeIds = append(newTypeIds, e.TypeIds...)
e.TypeIds = newTypeIds

newTokens := make([]string, padLength)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ go 1.23.0
toolchain go1.24.4

require (
github.com/dlclark/regexp2 v1.11.5
github.com/emirpasic/gods v1.18.1
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/rivo/uniseg v0.4.7
github.com/schollz/progressbar/v2 v2.15.0
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c
golang.org/x/sync v0.14.0
golang.org/x/text v0.25.0
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
Expand All @@ -17,8 +19,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
Expand Down
6 changes: 2 additions & 4 deletions model/bpe/bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"

// "strconv"
Expand Down Expand Up @@ -276,9 +275,8 @@ func (b *BPE) ReadFiles(vocabF string, mergesF string) (*model.Vocab, *Merges, e
for s.Scan() {
line := s.Text()

// Skip line with `#version`
re := regexp.MustCompile(`#version`)
if re.MatchString(line) {
// Skip version header line.
if strings.HasPrefix(line, "#version") {
continue
}

Expand Down
Loading