diff --git a/added-vocabulary.go b/added-vocabulary.go index cafeb11..cdc2d60 100644 --- a/added-vocabulary.go +++ b/added-vocabulary.go @@ -7,7 +7,6 @@ import ( "sort" "unicode" - "github.com/sugarme/regexpset" "github.com/sugarme/tokenizer/normalizer" ) @@ -180,8 +179,8 @@ func isWordCharacter(r rune) bool { // matchingSet is a set of regular expression string type matchingSet struct { - regexSet regexpset.RegexpSet - ids []int + ids []int + regexps []*regexp.Regexp } // AddedVocabulary is a vocabulary built on top of the Model @@ -346,17 +345,18 @@ func (av *AddedVocabulary) refreshAddedTokens(model Model, normalizer normalizer } } - normSet, err := regexpset.NewRegexpSet(normPatterns) - if err != nil { - log.Fatal(err) + normRegexps := make([]*regexp.Regexp, len(normPatterns)) + for i, p := range normPatterns { + normRegexps[i] = regexp.MustCompile(p) } - nnormSet, err := regexpset.NewRegexpSet(nnormPatterns) - if err != nil { - log.Fatal(err) + + nnormRegexps := make([]*regexp.Regexp, len(nnormPatterns)) + for i, p := range nnormPatterns { + nnormRegexps[i] = regexp.MustCompile(p) } - av.splitNormalizedRe = matchingSet{*normSet, normIds} - av.splitRe = matchingSet{*nnormSet, nnormIds} + av.splitNormalizedRe = matchingSet{normIds, normRegexps} + av.splitRe = matchingSet{nnormIds, nnormRegexps} } type idOffsets struct { @@ -390,11 +390,9 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re return []idOffsets{{-1, []int{0, 0}}} } - matches := splitRe.regexSet.Matches(sentence).Matches() - var ioPairs []idOffsets + ioPairs := make([]idOffsets, 0, len(splitRe.regexps)*2) - for _, idx := range matches { - r := regexp.MustCompile(splitRe.regexSet.Patterns()[idx]) + for idx, r := range splitRe.regexps { locs := r.FindAllStringIndex(sentence, -1) for _, loc := range locs { id := idx @@ -403,15 +401,22 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re } } - // Sort id-offsets by start then by pattern id - sort.Sort(byStart(ioPairs)) - sort.Sort(byId(ioPairs)) + // Sort id-offsets by start, then by pattern id. + sort.Slice(ioPairs, func(i, j int) bool { + if ioPairs[i].offsets[0] != ioPairs[j].offsets[0] { + return ioPairs[i].offsets[0] < ioPairs[j].offsets[0] + } + if ioPairs[i].offsets[1] != ioPairs[j].offsets[1] { + return ioPairs[i].offsets[1] < ioPairs[j].offsets[1] + } + return ioPairs[i].id < ioPairs[j].id + }) - // Select the matches, if they overlap, keep them + // Select matches greedily. With sort(start, id), overlapping ties pick lowest id. var ( - i int = 0 - currentOffsets int = 0 - splits []idOffsets = make([]idOffsets, 0) + i int = 0 + currentOffsets int = 0 + splits = make([]idOffsets, 0, len(ioPairs)) ) for i < len(ioPairs) { @@ -423,20 +428,6 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re continue } - // Find out whether having overlapping neighbours. - // If so, keep the one with lowest Idx. All other will be skipped - // because `currentOffsets` will have been increased. - if i+1 < len(ioPairs) { - overlapPairs := ioPairs[i:] - sort.Sort(byId(overlapPairs)) - lowestPair := overlapPairs[0] // lowest Id one - splits = append(splits, lowestPair) - currentOffsets = ioPair.offsets[1] - i++ - continue - } - - // Not found overlap neighbours. Just apply itself splits = append(splits, ioPair) currentOffsets = ioPair.offsets[1] i++ @@ -445,7 +436,7 @@ func (av *AddedVocabulary) findMatches(sentence string, splitRe matchingSet) (re // Also, insert the splits in-between added tokens, to split the entire string var ( startOffset int = 0 - finalSplits []idOffsets + finalSplits = make([]idOffsets, 0, len(splits)*2+1) ) for _, ioPair := range splits { @@ -501,6 +492,9 @@ func (av *AddedVocabulary) splitWithIndices(sentence *normalizer.NormalizedStrin // input sentence `I read a book Yesterday`, if the normalizer is supposed to lowercase // everything, we expect a match. func (av *AddedVocabulary) ExtractAndNormalize(sequence string, n normalizer.Normalizer) *PreTokenizedString { + if len(av.splitRe.regexps) == 0 && len(av.splitNormalizedRe.regexps) == 0 { + return NewPreTokenizedString(sequence) + } pretokenized := NewPreTokenizedString(sequence) @@ -525,12 +519,71 @@ func (av *AddedVocabulary) ExtractAndNormalize(sequence string, n normalizer.Nor return pretok2 } +// ExtractAndNormalizeFast is like ExtractAndNormalize but creates +// NormalizedStrings without offset tracking for better performance. +// Use this when only token IDs are needed and offset mappings are not required. +func (av *AddedVocabulary) ExtractAndNormalizeFast(sequence string, n normalizer.Normalizer) *PreTokenizedString { + if len(av.splitRe.regexps) == 0 && len(av.splitNormalizedRe.regexps) == 0 { + return NewPreTokenizedStringFast(sequence) + } + + pretokenized := NewPreTokenizedStringFast(sequence) + + pretok1 := pretokenized.Split(func(idx int, seq *normalizer.NormalizedString) []SplitIdx { + return av.splitWithIndices(seq, av.splitRe) + }) + + pretok2 := pretok1.Split(func(i int, seq *normalizer.NormalizedString) []SplitIdx { + newSeq := seq + var err error + if n != nil { + newSeq, err = n.Normalize(seq) + if err != nil { + log.Fatal(err) + } + } + return av.splitWithIndices(newSeq, av.splitNormalizedRe) + }) + + return pretok2 +} + type AddedTokenWithId struct { Id int // Id assigned to this token Special bool // whether this is a special token Token AddedToken // the target AddedToken } +// AddTokensWithIds registers tokens with explicit IDs from the tokenizer.json, +// preserving the exact ID assignments rather than computing new ones. +// This is critical for tokenizers with compacted vocabularies where the +// added_tokens array specifies exact ID values that must be respected. +func (av *AddedVocabulary) AddTokensWithIds(tokenIds []AddedTokenWithId, model Model, normalizer normalizer.Normalizer) int { + added := 0 + for _, ti := range tokenIds { + if ti.Token.Content == "" { + continue + } + + // Register with the specified ID, unconditionally. + av.addedTokenMap[ti.Token.Content] = ti.Id + av.addedTokenMapR[ti.Id] = ti.Token.Content + + if ti.Special { + if _, exists := av.specialTokensSet[ti.Token.Content]; !exists { + av.specialTokens = append(av.specialTokens, ti.Token) + av.specialTokensSet[ti.Token.Content] = true + } + } else { + av.addedTokens = append(av.addedTokens, ti.Token) + } + added++ + } + + av.refreshAddedTokens(model, normalizer) + return added +} + // Implement Serialize interface for AddedVocabular: // ================================================= diff --git a/encoding.go b/encoding.go index 341bc7f..e6cd432 100644 --- a/encoding.go +++ b/encoding.go @@ -134,10 +134,54 @@ func NewEncodingFromTokens(tokens []Token, typeId int) (retVal *Encoding) { } func (e *Encoding) Clone() *Encoding { - out := new(Encoding) - err := util.DeepCopy(e, out) - if err != nil { - panic(err) + out := cloneEncoding(*e) + return &out +} + +func cloneRange(r Range) Range { + if r == nil { + return nil + } + out := make(Range, len(r)) + copy(out, r) + return out +} + +func rangesEqual(a, b Range) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func cloneEncoding(in Encoding) Encoding { + out := Encoding{} + + out.Ids = append([]int(nil), in.Ids...) + out.TypeIds = append([]int(nil), in.TypeIds...) + out.Tokens = append([]string(nil), in.Tokens...) + out.SpecialTokenMask = append([]int(nil), in.SpecialTokenMask...) + out.AttentionMask = append([]int(nil), in.AttentionMask...) + out.Words = append([]int(nil), in.Words...) + + out.Offsets = make([][]int, len(in.Offsets)) + for i, o := range in.Offsets { + out.Offsets[i] = append([]int(nil), o...) + } + + out.Overflowing = make([]Encoding, len(in.Overflowing)) + for i := range in.Overflowing { + out.Overflowing[i] = cloneEncoding(in.Overflowing[i]) + } + + out.SequenceRanges = make(map[int]Range, len(in.SequenceRanges)) + for k, r := range in.SequenceRanges { + out.SequenceRanges[k] = cloneRange(r) } return out @@ -459,9 +503,15 @@ func (e *Encoding) MergeWith(pair *Encoding, growingOffsets bool) (retVal *Encod start := originalLen + r[0] end := originalLen + r[r.Len()-1] + 1 newRange := NewRange(start, end) - var oldRange Range - util.DeepCopy(e.SequenceRanges[seqId], oldRange) - e.SequenceRanges[seqId] = util.Merge(oldRange, newRange) + oldRange, ok := e.SequenceRanges[seqId] + if !ok || len(oldRange) == 0 { + e.SequenceRanges[seqId] = newRange + continue + } + if rangesEqual(oldRange, newRange) { + continue + } + e.SequenceRanges[seqId] = util.Merge(cloneRange(oldRange), newRange) } } @@ -519,8 +569,16 @@ func mergeEncoding(en1, en2 Encoding, growingOffsets bool) Encoding { start := originalLen + r[0] end := originalLen + r[r.Len()-1] + 1 newRange := NewRange(start, end) - oldRange := en1.SequenceRanges[seqId] - sequenceRanges[seqId] = append(oldRange, newRange...) + oldRange, ok := en1.SequenceRanges[seqId] + if !ok || len(oldRange) == 0 { + sequenceRanges[seqId] = newRange + continue + } + if rangesEqual(oldRange, newRange) { + sequenceRanges[seqId] = oldRange + continue + } + sequenceRanges[seqId] = util.Merge(cloneRange(oldRange), newRange) } } else { sequenceRanges = en1.SequenceRanges @@ -586,7 +644,7 @@ func (e *Encoding) pad(targetLength, padId, padTypeId int, padToken string, dire for i := 0; i < len(newTypeIds); i++ { newTypeIds[i] = padTypeId } - newTypeIds = append(newTypeIds, e.Ids...) + newTypeIds = append(newTypeIds, e.TypeIds...) e.TypeIds = newTypeIds newTokens := make([]string, padLength) diff --git a/go.mod b/go.mod index 8076b5c..fe2203b 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.23.0 toolchain go1.24.4 require ( + github.com/dlclark/regexp2 v1.11.5 github.com/emirpasic/gods v1.18.1 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rivo/uniseg v0.4.7 github.com/schollz/progressbar/v2 v2.15.0 - github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c golang.org/x/sync v0.14.0 golang.org/x/text v0.25.0 ) diff --git a/go.sum b/go.sum index d4b8ed6..1ab8b1b 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= +github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= @@ -17,8 +19,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4= -github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw= golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= diff --git a/model/bpe/bpe.go b/model/bpe/bpe.go index d408818..86e97bf 100644 --- a/model/bpe/bpe.go +++ b/model/bpe/bpe.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "os" "path/filepath" - "regexp" "sort" // "strconv" @@ -276,9 +275,8 @@ func (b *BPE) ReadFiles(vocabF string, mergesF string) (*model.Vocab, *Merges, e for s.Scan() { line := s.Text() - // Skip line with `#version` - re := regexp.MustCompile(`#version`) - if re.MatchString(line) { + // Skip version header line. + if strings.HasPrefix(line, "#version") { continue } diff --git a/normalizer/normalized.go b/normalizer/normalized.go index cd1d539..db26d1f 100644 --- a/normalizer/normalized.go +++ b/normalizer/normalized.go @@ -149,6 +149,11 @@ type NormalizedString struct { // of the missing part, so that we can still give offsets from this original // string. originalShift int + // When true, skip all alignment tracking for performance. The resulting + // NormalizedString cannot convert offsets between original and normalized + // referentials, but all other operations (Split, Transform, etc.) work + // correctly. Use NewNormalizedFromFast to create such instances. + skipOffsets bool } // NewNormalizedFrom creates a Normalized instance from string input @@ -168,6 +173,21 @@ func NewNormalizedFrom(s string) (retVal *NormalizedString) { } } +// NewNormalizedFromFast creates a NormalizedString without alignment tracking. +// This is significantly faster than NewNormalizedFrom because it skips the +// allocation of per-byte alignment arrays. The resulting NormalizedString +// cannot convert offsets between original and normalized referentials, but +// all mutation operations (Split, Transform, Prepend, etc.) still produce +// correct normalized text. Use this when you only need token IDs and do not +// need character-level offset mappings back to the original input. +func NewNormalizedFromFast(s string) *NormalizedString { + return &NormalizedString{ + original: s, + normalized: s, + skipOffsets: true, + } +} + // createALigns creates alignments from input string. // NOTE:It is used in `NewNormalizedFrom` to create 2 slices // (alignments and alignmentsOriginal) without sharing data @@ -374,6 +394,15 @@ func (n *NormalizedString) Slice(inputRange *Range) (retVal *NormalizedString) { return nil } + if n.skipOffsets { + r := fullRange.IntoFullRange(len(n.normalized)) + return &NormalizedString{ + original: n.normalized[r.start:r.end], + normalized: n.normalized[r.start:r.end], + skipOffsets: true, + } + } + // 1. Find range on normalized (nRange) and original string (oRange) var nRange, oRange *Range switch fullRange.indexOn { @@ -402,6 +431,9 @@ func (n *NormalizedString) Slice(inputRange *Range) (retVal *NormalizedString) { sOriginalShift int ) + sAlignments = make([][]int, 0, nRange.end-nRange.start) + sAlignmentsOriginal = make([][]int, 0, oRange.end-oRange.start) + sOriginal = n.RangeOriginal(fullRange) sNormalized = n.Range(fullRange) for _, a := range n.alignments[nRange.start:nRange.end] { @@ -439,6 +471,23 @@ type ChangeMap struct { // of removed chars at the very beginning. func (n *NormalizedString) TransformRange(inputRange *Range, changeMap []ChangeMap, initialOffset int) (retVal *NormalizedString) { + if n.skipOffsets { + var nStart, nEnd int + switch inputRange.indexOn { + case NormalizedTarget: + r := inputRange.IntoFullRange(n.Len()) + nStart, nEnd = r.start, r.end + default: // OriginalTarget — only from Transform() which uses the full range + nStart, nEnd = 0, len(n.normalized) + } + var buf strings.Builder + for _, item := range changeMap { + buf.WriteString(item.RuneVal) + } + n.normalized = n.normalized[:nStart] + buf.String() + n.normalized[nEnd:] + return n + } + // fmt.Printf("normalized: %+v\n", n) // fmt.Printf("inputRange: %v\n", inputRange) @@ -1065,6 +1114,10 @@ func (n *NormalizedString) Filter(fn func(rune) bool) (retVal *NormalizedString) // Prepend adds given string to the begining of NormalizedString func (n *NormalizedString) Prepend(s string) (retVal *NormalizedString) { + if n.skipOffsets { + n.normalized = s + n.normalized + return n + } chars := []rune(n.normalized) var changeMap []ChangeMap if len(chars) == 0 { @@ -1089,6 +1142,10 @@ func (n *NormalizedString) Prepend(s string) (retVal *NormalizedString) { // Append adds given string to the end of NormalizedString func (n *NormalizedString) Append(s string) (retVal *NormalizedString) { + if n.skipOffsets { + n.normalized = n.normalized + s + return n + } if n.normalized == "" { return n @@ -1196,6 +1253,7 @@ func (n *NormalizedString) Split(pattern Pattern, behavior SplitDelimiterBehavio // Process the matches according to the selected behavior: []OfssetsMatch // where `Match` field is `shouldRemove` var splits []OffsetsMatch + splits = make([]OffsetsMatch, 0, len(matches)) switch behavior { case IsolatedBehavior: for _, m := range matches { @@ -1206,7 +1264,7 @@ func (n *NormalizedString) Split(pattern Pattern, behavior SplitDelimiterBehavio splits = matches case MergedWithPreviousBehavior: previousMatch := false - var acc []OffsetsMatch + acc := make([]OffsetsMatch, 0, len(matches)) for _, m := range matches { if m.Match && !previousMatch { if len(acc) > 0 { @@ -1224,7 +1282,7 @@ func (n *NormalizedString) Split(pattern Pattern, behavior SplitDelimiterBehavio splits = acc case ContiguousBehavior: previousMatch := false - var acc []OffsetsMatch + acc := make([]OffsetsMatch, 0, len(matches)) for _, m := range matches { if m.Match == previousMatch { if len(acc) > 0 { @@ -1243,7 +1301,7 @@ func (n *NormalizedString) Split(pattern Pattern, behavior SplitDelimiterBehavio case MergedWithNextBehavior: previousMatch := false - var acc []OffsetsMatch + acc := make([]OffsetsMatch, 0, len(matches)) // iterate reversively for i := len(matches) - 1; i >= 0; i-- { m := matches[i] @@ -1268,7 +1326,7 @@ func (n *NormalizedString) Split(pattern Pattern, behavior SplitDelimiterBehavio } // Then split according to the computed splits - var slices []NormalizedString + slices := make([]NormalizedString, 0, len(splits)) for _, split := range splits { if !split.Match { slice := n.Slice(NewRange(split.Offsets[0], split.Offsets[1], NormalizedTarget)) diff --git a/normalizer/pattern.go b/normalizer/pattern.go index 1a5ebe7..b4a834e 100644 --- a/normalizer/pattern.go +++ b/normalizer/pattern.go @@ -1,13 +1,67 @@ package normalizer import ( - "log" - // "reflect" "regexp" + "sync" + "unicode/utf8" - "github.com/sugarme/tokenizer/util" + "github.com/dlclark/regexp2" ) +var runeToBytePool = sync.Pool{ + New: func() any { + buf := make([]int, 0, 256) + return &buf + }, +} + +var regexp2MatchPool = sync.Pool{ + New: func() any { + buf := make([][2]int, 0, 16) + return &buf + }, +} + +func getRuneToByteScratch(minCap int) *[]int { + buf := runeToBytePool.Get().(*[]int) + if cap(*buf) < minCap { + *buf = make([]int, 0, minCap) + } else { + *buf = (*buf)[:0] + } + return buf +} + +func putRuneToByteScratch(buf *[]int) { + const maxRetainedCap = 1 << 20 + if cap(*buf) > maxRetainedCap { + *buf = make([]int, 0, 256) + } else { + *buf = (*buf)[:0] + } + runeToBytePool.Put(buf) +} + +func getRegexp2MatchScratch(minCap int) *[][2]int { + buf := regexp2MatchPool.Get().(*[][2]int) + if cap(*buf) < minCap { + *buf = make([][2]int, 0, minCap) + } else { + *buf = (*buf)[:0] + } + return buf +} + +func putRegexp2MatchScratch(buf *[][2]int) { + const maxRetainedCap = 1 << 14 + if cap(*buf) > maxRetainedCap { + *buf = make([][2]int, 0, 16) + } else { + *buf = (*buf)[:0] + } + regexp2MatchPool.Put(buf) +} + // Pattern is used to split a NormalizedString type Pattern interface { // FindMatches slices the given string in a list of pattern match positions, with @@ -55,7 +109,7 @@ func (r *RunePattern) FindMatches(inside string) []OffsetsMatch { for byteIdx, char := range inside { if char == r.rune { - nextIdx := byteIdx + len(string(char)) + nextIdx := byteIdx + utf8.RuneLen(char) // 1. Add previous unmatched if any if hasPrevious { prev := OffsetsMatch{Offsets: []int{prevStart, byteIdx}, Match: false} @@ -89,16 +143,18 @@ func (r *RunePattern) FindMatches(inside string) []OffsetsMatch { // String is a wrapper of primitive string // so that it can implement `Pattern` interface type StringPattern struct { - string + s string + re *regexp.Regexp } func NewStringPattern(s string) *StringPattern { - return &StringPattern{s} + quoted := regexp.QuoteMeta(s) + return &StringPattern{s: s, re: regexp.MustCompile(quoted)} } func (s *StringPattern) FindMatches(inside string) []OffsetsMatch { // If we try to find the matches with an empty string, just don't match anything - if s.string == "" { + if s.s == "" { return []OffsetsMatch{ { Offsets: []int{0, len(inside)}, @@ -107,16 +163,10 @@ func (s *StringPattern) FindMatches(inside string) []OffsetsMatch { } } - quoted := regexp.QuoteMeta(s.string) - - re := regexp.MustCompile(quoted) - - return findMatches(re, inside) + return findMatches(s.re, inside) } -func findMatches(re *regexp.Regexp, inside string) []OffsetsMatch { - - matches := re.FindAllStringIndex(inside, -1) +func buildMatchesFromIndices(matches [][]int, inside string) []OffsetsMatch { // 0. If no matches, just return if len(matches) == 0 { @@ -130,7 +180,7 @@ func findMatches(re *regexp.Regexp, inside string) []OffsetsMatch { var ( currIndex int = 0 - subs []OffsetsMatch + subs = make([]OffsetsMatch, 0, len(matches)*2+1) ) // 1. Sub before matched if any @@ -184,15 +234,116 @@ func findMatches(re *regexp.Regexp, inside string) []OffsetsMatch { return subs } +func findMatches(re *regexp.Regexp, inside string) []OffsetsMatch { + matches := re.FindAllStringIndex(inside, -1) + return buildMatchesFromIndices(matches, inside) +} + +// collectRegexp2Indices collects all match [start, end) rune offsets from re +// into dst, reusing the slice to avoid per-call allocation. +func collectRegexp2Indices(re *regexp2.Regexp, s string, dst [][2]int) [][2]int { + dst = dst[:0] + m, err := re.FindStringMatch(s) + for err == nil && m != nil { + dst = append(dst, [2]int{m.Index, m.Index + m.Length}) + m, err = re.FindNextMatch(m) + } + return dst +} + +func findMatchesRegexp2(re *regexp2.Regexp, inside string) []OffsetsMatch { + asciiOnly := true + for i := 0; i < len(inside); i++ { + if inside[i] >= utf8.RuneSelf { + asciiOnly = false + break + } + } + + var ( + runeToByte []int + runeToByteBuf *[]int + ) + if !asciiOnly { + runeToByteBuf = getRuneToByteScratch(utf8.RuneCountInString(inside) + 1) + runeToByte = *runeToByteBuf + for byteIdx := range inside { + runeToByte = append(runeToByte, byteIdx) + } + runeToByte = append(runeToByte, len(inside)) + *runeToByteBuf = runeToByte + } + + toByte := func(runeIdx int) int { + if runeIdx < 0 { + return 0 + } + if asciiOnly { + if runeIdx > len(inside) { + return len(inside) + } + return runeIdx + } + if runeIdx >= len(runeToByte) { + return len(inside) + } + return runeToByte[runeIdx] + } + + runeMatchesBuf := getRegexp2MatchScratch(8) + runeMatches := collectRegexp2Indices(re, inside, *runeMatchesBuf) + if len(runeMatches) == 0 { + putRegexp2MatchScratch(runeMatchesBuf) + if runeToByteBuf != nil { + putRuneToByteScratch(runeToByteBuf) + } + return []OffsetsMatch{{Offsets: []int{0, len(inside)}, Match: false}} + } + + matches := make([]OffsetsMatch, 0, len(runeMatches)*2+1) + curr := 0 + for _, rm := range runeMatches { + start := toByte(rm[0]) + end := toByte(rm[1]) + if start > curr { + matches = append(matches, OffsetsMatch{Offsets: []int{curr, start}, Match: false}) + } + matches = append(matches, OffsetsMatch{Offsets: []int{start, end}, Match: true}) + curr = end + } + *runeMatchesBuf = runeMatches + putRegexp2MatchScratch(runeMatchesBuf) + + if runeToByteBuf != nil { + putRuneToByteScratch(runeToByteBuf) + } + + if curr < len(inside) { + matches = append(matches, OffsetsMatch{Offsets: []int{curr, len(inside)}, Match: false}) + } + + return matches +} + +// RegexpPattern uses github.com/dlclark/regexp2 for regex matching, +// which supports lookaheads, lookbehinds, and other features not +// available in Go's standard regexp package. This enables compatibility +// with tokenizer patterns used by GPT-4, Qwen, Llama 3, and other +// modern models that rely on .NET/PCRE-style regex syntax. type RegexpPattern struct { - re *regexp.Regexp + re *regexp2.Regexp + source string } +// NewRegexpPattern compiles the given pattern using regexp2 and returns +// a RegexpPattern that implements the Pattern interface. Panics if the +// pattern cannot be compiled. func NewRegexpPattern(s string) *RegexpPattern { - re := regexp.MustCompile(s) - return &RegexpPattern{ - re: re, + re, err := regexp2.Compile(s, 0) + if err != nil { + panic(err) } + return &RegexpPattern{re: re, source: s} } // FindMatches implements Pattern interface for RegexpPattern @@ -206,7 +357,7 @@ func (rp *RegexpPattern) FindMatches(inside string) []OffsetsMatch { } } - return findMatches(rp.re, inside) + return findMatchesRegexp2(rp.re, inside) } // PatternFn is a func type to apply pattern @@ -239,7 +390,7 @@ func (fp *FnPattern) FindMatches(inside string) []OffsetsMatch { for byteIdx, char := range inside { if fp.fn(char) { - nextIdx := byteIdx + len(string(char)) + nextIdx := byteIdx + utf8.RuneLen(char) // 1. Add previous unmatched if any if hasPrevious { prev := OffsetsMatch{Offsets: []int{prevStart, byteIdx}, Match: false} @@ -284,30 +435,14 @@ func NewInvertPattern(p Pattern) *Invert { // FindMatches implement Pattern interface for Invert func (i *Invert) FindMatches(inside string) []OffsetsMatch { - var matches []OffsetsMatch - typ := util.GetType(i.Pattern) - switch typ { - case "*StringPattern": - matches = i.Pattern.(*StringPattern).FindMatches(inside) - case "*RunePattern": - matches = i.Pattern.(*RunePattern).FindMatches(inside) - case "*FnPattern": - matches = i.Pattern.(*FnPattern).FindMatches(inside) - case "*RegexpPattern": - matches = i.Pattern.(*RegexpPattern).FindMatches(inside) - - default: - log.Fatalf("Unsupported type - %q\n", typ) - } - - return invert(matches) + return invert(i.Pattern.FindMatches(inside)) } func invert(matches []OffsetsMatch) (retVal []OffsetsMatch) { - var res []OffsetsMatch - for _, m := range matches { + res := make([]OffsetsMatch, len(matches)) + for i, m := range matches { m.Match = !m.Match - res = append(res, m) + res[i] = m } return res diff --git a/pretokenizer.go b/pretokenizer.go index 6a55545..d2c275a 100644 --- a/pretokenizer.go +++ b/pretokenizer.go @@ -5,6 +5,7 @@ package tokenizer import ( "fmt" "log" + "unicode/utf8" // "reflect" "github.com/sugarme/tokenizer/normalizer" @@ -71,14 +72,24 @@ type SplitFn func(int, *normalizer.NormalizedString) []SplitIdx // func (pt *PreTokenizedString) Split(splitFn SplitFn) *PreTokenizedString { func (pt *PreTokenizedString) Split(splitFn SplitFn) *PreTokenizedString { - var newSplits []Split + newSplits := make([]Split, 0, len(pt.splits)) for i, originalSplit := range pt.splits { if originalSplit.tokens != nil { newSplits = append(newSplits, originalSplit) continue } - for _, splitIdx := range splitFn(i, originalSplit.normalized) { + splitIdxs := splitFn(i, originalSplit.normalized) + if cap(newSplits)-len(newSplits) < len(splitIdxs) { + grow := len(splitIdxs) + if grow < len(pt.splits) { + grow = len(pt.splits) + } + tmp := make([]Split, len(newSplits), len(newSplits)+grow) + copy(tmp, newSplits) + newSplits = tmp + } + for _, splitIdx := range splitIdxs { if splitIdx.Normalized.GetNormalized() != "" { // split := NewSplit(splitIdx.Normalized, splitIdx.Tokens) split := Split{ @@ -98,7 +109,7 @@ func (pt *PreTokenizedString) Split(splitFn SplitFn) *PreTokenizedString { // using the provided `normalize` function. func (pt *PreTokenizedString) Normalize(nFn func(*normalizer.NormalizedString) *normalizer.NormalizedString) *PreTokenizedString { - var nSplits []Split + nSplits := make([]Split, 0, len(pt.splits)) for _, split := range pt.splits { if split.tokens == nil { @@ -115,7 +126,7 @@ func (pt *PreTokenizedString) Normalize(nFn func(*normalizer.NormalizedString) * // Tokenize tokenizes all the splits that do not have attached `Tokens`, using the provided // `tokenize` function func (pt *PreTokenizedString) Tokenize(tokFn func(*normalizer.NormalizedString) ([]Token, error)) (*PreTokenizedString, error) { - var nSplits []Split + nSplits := make([]Split, 0, len(pt.splits)) for _, split := range pt.splits { newSplit := split @@ -160,7 +171,8 @@ func (pt *PreTokenizedString) IntoEncoding(typeId int, wordIdx int, offsetType O currRuneIdx := 0 for byteIdx, r := range pt.original { n := 0 - for i := 0; i < len([]byte(string(r))); i++ { + runeLen := utf8.RuneLen(r) + for i := 0; i < runeLen; i++ { charMap[byteIdx+n] = currRuneIdx n += 1 } @@ -254,6 +266,32 @@ func (pt *PreTokenizedString) IntoEncoding(typeId int, wordIdx int, offsetType O return en, nil } +// IntoIDs extracts just the token IDs from a tokenized PreTokenizedString. +// This is a faster alternative to IntoEncoding when only IDs are needed, +// as it skips offset conversion and Encoding struct construction. +// This method will fail if some splits do not have associated Tokens. +func (pt *PreTokenizedString) IntoIDs() ([]int, error) { + if len(pt.splits) == 0 { + return nil, nil + } + + n := 0 + for _, s := range pt.splits { + if s.tokens == nil { + return nil, fmt.Errorf("Split has not been tokenized. Call 'PreTokenizedString.Tokenize()' method first.\n") + } + n += len(s.tokens) + } + + ids := make([]int, 0, n) + for _, s := range pt.splits { + for _, tok := range s.tokens { + ids = append(ids, tok.Id) + } + } + return ids, nil +} + // GetSplits returns a list of splits, each of them being a slice of the normalized // string, the associated offsets either in original or normalized // referential, as well as the potention tokens @@ -308,6 +346,13 @@ func NewPreTokenizedString(s string) *PreTokenizedString { return NewPreTokenizedStringFromNS(n) } +// NewPreTokenizedStringFast creates a PreTokenizedString without offset tracking. +// See normalizer.NewNormalizedFromFast for details on the performance trade-off. +func NewPreTokenizedStringFast(s string) *PreTokenizedString { + n := normalizer.NewNormalizedFromFast(s) + return NewPreTokenizedStringFromNS(n) +} + type OffsetConverter interface { Convert(offsets []int) ([]int, error) } @@ -322,7 +367,7 @@ func NewBytesToCharOffsetConverter(sequence string) *BytesToCharOffsetConverter b2c := make(map[int]int) n := 0 for charIdx, char := range chars { - nbytes := len([]byte(string(char))) + nbytes := utf8.RuneLen(char) for i := 0; i < nbytes; i++ { byteIdx := n + i b2c[byteIdx] = charIdx diff --git a/pretokenizer/bytelevel.go b/pretokenizer/bytelevel.go index f3e3ff6..fa93864 100644 --- a/pretokenizer/bytelevel.go +++ b/pretokenizer/bytelevel.go @@ -1,7 +1,6 @@ package pretokenizer import ( - "regexp" "strings" "github.com/sugarme/tokenizer" @@ -17,7 +16,7 @@ import ( // TODO: this RE does not cover the case with trailing whitespace!!! const splitRegStr = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+` -var splitRE = regexp.MustCompile(splitRegStr) +var splitPattern = normalizer.NewRegexpPattern(splitRegStr) var BytesChar map[uint8]string = GenerateBytesChar() @@ -143,6 +142,11 @@ type ByteLevel struct { // Whether the post processing step should trim offsets // to avoid including whitespaces. TrimOffsets bool + + // Whether to apply GPT-2 regex splitting before byte mapping. + // This defaults to true and should be disabled when a prior + // pre-tokenizer already performs splitting. + UseRegex bool } // NewByteLevel returns a default ByteLevel with both @@ -151,6 +155,7 @@ func NewByteLevel() *ByteLevel { return &ByteLevel{ AddPrefixSpace: true, TrimOffsets: true, + UseRegex: true, } } @@ -174,6 +179,12 @@ func (bl *ByteLevel) SetTrimOffsets(v bool) { bl.TrimOffsets = v } +// SetUseRegex controls whether byte-level pretokenization applies +// an additional regex split stage before byte mapping. +func (bl *ByteLevel) SetUseRegex(v bool) { + bl.UseRegex = v +} + // Implement `PreTokenizer` methods for `ByteLevel`: // ================================================= @@ -187,10 +198,13 @@ func (bl *ByteLevel) PreTokenize(pretokenized *tokenizer.PreTokenizedString) (*t newNormalized = normalized.Prepend(" ") } - splitPattern := normalizer.NewRegexpPattern(splitRegStr) + if !bl.UseRegex { + return []tokenizer.SplitIdx{{Normalized: newNormalized, Tokens: nil}} + } + splits := newNormalized.Split(splitPattern, normalizer.IsolatedBehavior) - var splitIdx []tokenizer.SplitIdx + splitIdx := make([]tokenizer.SplitIdx, 0, len(splits)) for _, s := range splits { split := s // NOTE: to deep copy variable otherwise its updated to the last item as we will pass its pointer. splitIdx = append(splitIdx, tokenizer.SplitIdx{Normalized: &split, Tokens: nil}) diff --git a/pretokenizer/bytelevel_test.go b/pretokenizer/bytelevel_test.go index fd25d2d..6b366f2 100644 --- a/pretokenizer/bytelevel_test.go +++ b/pretokenizer/bytelevel_test.go @@ -186,6 +186,34 @@ func TestHandlingOfNewLines(t *testing.T) { } } +func TestNoRegexSplitKeepsSingleSpan(t *testing.T) { + bytelevel := pretokenizer.NewByteLevel() + bytelevel.SetAddPrefixSpace(false) + bytelevel.SetUseRegex(false) + + input := "Hello there\nfriend!" + pretokenized := tokenizer.NewPreTokenizedString(input) + + pretok, err := bytelevel.PreTokenize(pretokenized) + if err != nil { + t.Error(err) + } + + splits := pretok.GetSplits(normalizer.OriginalTarget, tokenizer.Byte) + if len(splits) != 1 { + t.Fatalf("want 1 split when use_regex=false, got %d", len(splits)) + } + + if !reflect.DeepEqual([]int{0, len(input)}, splits[0].Offsets) { + t.Fatalf("want offsets %v, got %v", []int{0, len(input)}, splits[0].Offsets) + } + + decoded := bytelevel.Decode(strings.Split(splits[0].Value, "")) + if decoded != input { + t.Fatalf("want %q after decode, got %q", input, decoded) + } +} + func TestHandlingOfMultipleSpaces(t *testing.T) { bytelevel := pretokenizer.NewByteLevel() diff --git a/pretokenizer/sequence.go b/pretokenizer/sequence.go index f860d9c..789962e 100644 --- a/pretokenizer/sequence.go +++ b/pretokenizer/sequence.go @@ -14,6 +14,13 @@ func NewSequence(pretokenizers []tokenizer.PreTokenizer) *Sequence { return &Sequence{pretokenizers} } +// PreTokenizers returns the underlying pretokenizer slice, allowing +// callers to inspect the pipeline (e.g., to detect a ByteLevel stage +// for fast-path optimizations). +func (p *Sequence) PreTokenizers() []tokenizer.PreTokenizer { + return p.pretokenizers +} + // Implement tokenizer.PreTokenizer for Sequence func (p *Sequence) PreTokenize(v *tokenizer.PreTokenizedString) (*tokenizer.PreTokenizedString, error) { diff --git a/pretokenizer/split.go b/pretokenizer/split.go index 8a2fc93..236eafa 100644 --- a/pretokenizer/split.go +++ b/pretokenizer/split.go @@ -29,7 +29,7 @@ func (s *Split) PreTokenize(pretokenized *tokenizer.PreTokenizedString) (*tokeni invert := normalizer.NewInvertPattern(s.Pattern) splits := normalized.Split(invert, s.Behavior) - var splitIdxs []tokenizer.SplitIdx + splitIdxs := make([]tokenizer.SplitIdx, 0, len(splits)) for _, s := range splits { normalized := s splitIdx := tokenizer.SplitIdx{Normalized: &normalized, Tokens: nil} @@ -45,7 +45,7 @@ func (s *Split) PreTokenize(pretokenized *tokenizer.PreTokenizedString) (*tokeni pretok := pretokenized.Split(func(noop int, normalized *normalizer.NormalizedString) []tokenizer.SplitIdx { splits := normalized.Split(s.Pattern, s.Behavior) - var splitIdxs []tokenizer.SplitIdx + splitIdxs := make([]tokenizer.SplitIdx, 0, len(splits)) for _, s := range splits { normalized := s splitIdx := tokenizer.SplitIdx{Normalized: &normalized, Tokens: nil} diff --git a/pretrained/added-tokens.go b/pretrained/added-tokens.go index 253c012..ff8b033 100644 --- a/pretrained/added-tokens.go +++ b/pretrained/added-tokens.go @@ -22,3 +22,26 @@ func CreateAddedTokens(data []tokenizer.TokenConfig) (specialToks, toks []tokeni return specialToks, toks } + +// CreateAddedTokensWithIds preserves the explicit IDs from tokenizer.json +// instead of letting AddedVocabulary recompute them. This is required for +// tokenizers with compacted vocabularies where added token IDs are not +// simply model.GetVocabSize() + offset. +func CreateAddedTokensWithIds(data []tokenizer.TokenConfig) []tokenizer.AddedTokenWithId { + result := make([]tokenizer.AddedTokenWithId, 0, len(data)) + for _, d := range data { + tok := tokenizer.DefaultAddedToken() + tok.Content = d.Content + tok.LStrip = d.Lstrip + tok.Normalized = d.Normalized + tok.RStrip = d.Rstrip + tok.SingleWord = d.SingleWord + + result = append(result, tokenizer.AddedTokenWithId{ + Id: int(d.Id), + Special: d.Special, + Token: tok, + }) + } + return result +} diff --git a/pretrained/pretokenizer.go b/pretrained/pretokenizer.go index 87eb7fb..b4d5b3c 100644 --- a/pretrained/pretokenizer.go +++ b/pretrained/pretokenizer.go @@ -68,10 +68,12 @@ func createByteLevelPreTokenizer(params *util.Params) (tokenizer.PreTokenizer, e addPrefixSpace := params.Get("add_prefix_space", false).(bool) trimOffsets := params.Get("trim_offsets", false).(bool) + useRegex := params.Get("use_regex", true).(bool) return &pretokenizer.ByteLevel{ AddPrefixSpace: addPrefixSpace, TrimOffsets: trimOffsets, + UseRegex: useRegex, }, nil } diff --git a/pretrained/pretokenizer_test.go b/pretrained/pretokenizer_test.go index 47cf7bf..e0a4440 100644 --- a/pretrained/pretokenizer_test.go +++ b/pretrained/pretokenizer_test.go @@ -2,6 +2,8 @@ package pretrained import ( "testing" + + "github.com/sugarme/tokenizer/pretokenizer" ) func TestCreatePreTokenizer(t *testing.T) { @@ -29,3 +31,26 @@ func TestNullPreTokenizer(t *testing.T) { panic(err) } } + +func TestCreateByteLevelUseRegexOption(t *testing.T) { + config := map[string]interface{}{ + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false, + } + + pt, err := CreatePreTokenizer(config) + if err != nil { + t.Fatalf("CreatePreTokenizer error: %v", err) + } + + bl, ok := pt.(*pretokenizer.ByteLevel) + if !ok { + t.Fatalf("expected *pretokenizer.ByteLevel, got %T", pt) + } + + if bl.UseRegex { + t.Fatalf("expected UseRegex=false from config") + } +} diff --git a/pretrained/tokenizer.go b/pretrained/tokenizer.go index e5769eb..23c79a6 100644 --- a/pretrained/tokenizer.go +++ b/pretrained/tokenizer.go @@ -71,13 +71,11 @@ func FromReader(r io.Reader) (*tokenizer.Tokenizer, error) { } tk.WithDecoder(decoder) - // 6. AddedVocabulary - specialAddedTokens, addedTokens := CreateAddedTokens(config.AddedTokens) - if len(specialAddedTokens) > 0 { - tk.AddSpecialTokens(specialAddedTokens) - } - if len(addedTokens) > 0 { - tk.AddTokens(addedTokens) + // 6. AddedVocabulary — use ID-preserving path so that compacted + // tokenizers keep the exact added-token IDs from tokenizer.json. + addedTokensWithIds := CreateAddedTokensWithIds(config.AddedTokens) + if len(addedTokensWithIds) > 0 { + tk.AddTokensWithIds(addedTokensWithIds) } // 7. TruncationParams diff --git a/tokenizer.go b/tokenizer.go index 38f9bf1..4b889d2 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -540,6 +540,19 @@ func (t *Tokenizer) AddTokens(tokens []AddedToken) (retVal int) { return t.addedVocabulary.AddTokens(tokens, t.model, t.normalizer) } +// AddTokensWithIds registers tokens with explicit IDs, preserving +// the exact ID assignments from the tokenizer.json rather than +// computing new sequential IDs. +func (t *Tokenizer) AddTokensWithIds(tokens []AddedTokenWithId) int { + return t.addedVocabulary.AddTokensWithIds(tokens, t.model, t.normalizer) +} + +// GetAddedVocab returns only the added vocabulary (token -> id), +// excluding the base model vocabulary. +func (t *Tokenizer) GetAddedVocab() map[string]int { + return t.addedVocabulary.GetVocab() +} + // doNormalize does Normalization logic, go through all normalizers func (t *Tokenizer) doNormalize(s string) (retVal *normalizer.NormalizedString, err error) { normalized := normalizer.NewNormalizedFrom(s) @@ -641,19 +654,17 @@ func (t *Tokenizer) EncodeBatch(inputs []EncodeInput, addSpecialTokens bool) (re var ( encodings []Encoding = make([]Encoding, len(inputs)) eg errgroup.Group - mu = &sync.Mutex{} ) // Encoding concurrently for i := range inputs { + i := i eg.Go(func() error { e, err := t.Encode(inputs[i], addSpecialTokens) if err != nil { return err } - mu.Lock() encodings[i] = *e - mu.Unlock() return nil }) } @@ -672,7 +683,7 @@ func (t *Tokenizer) EncodeBatch(inputs []EncodeInput, addSpecialTokens bool) (re // DecodeBatch decodes all sentences in concurrency func (t *Tokenizer) DecodeBatch(sentences [][]int, skipSpecialTokens bool) []string { - var decodings []string + decodings := make([]string, len(sentences)) var wg sync.WaitGroup wg.Add(len(sentences)) @@ -683,7 +694,7 @@ func (t *Tokenizer) DecodeBatch(sentences [][]int, skipSpecialTokens bool) []str defer wg.Done() s := t.Decode(sentences[i], skipSpecialTokens) - decodings = append(decodings, s) + decodings[i] = s }(i) } @@ -1102,6 +1113,36 @@ func (t *Tokenizer) EncodePair(input, pair string, addSpecialTokensOpt ...bool) return t.Encode(encodeInput, addSpecialTokens) } +// EncodeIDsOnly returns only the token IDs for the input string. +// It is significantly faster than EncodeSingle because it skips offset +// tracking inside NormalizedString (no per-byte alignment arrays) and +// does not construct a full Encoding struct. +// The returned IDs are identical to EncodeSingle(input).Ids when +// addSpecialTokens is false and no truncation/padding is configured. +func (t *Tokenizer) EncodeIDsOnly(input string) ([]int, error) { + pretokenized := t.addedVocabulary.ExtractAndNormalizeFast(input, t.normalizer) + + if t.preTokenizer != nil { + var err error + pretokenized, err = t.preTokenizer.PreTokenize(pretokenized) + if err != nil { + return nil, err + } + } + + _, err := pretokenized.Tokenize(func(n *normalizer.NormalizedString) ([]Token, error) { + if t.model == nil { + return nil, fmt.Errorf("Tokenizer.EncodeIDsOnly() failed: no Model set") + } + return t.model.Tokenize(n.GetNormalized()) + }) + if err != nil { + return nil, err + } + + return pretokenized.IntoIDs() +} + // Tokenize slices input string into tokens. // // Params: