Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 35 additions & 81 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ type MergePair struct {
// BPETokenizer implements the Tokenizer interface using byte-pair encoding.
// It loads vocabulary and merge rules from HuggingFace tokenizer.json format.
// When scores are set and merges are empty, it falls back to SentencePiece
// unigram encoding using greedy longest-match with score-based selection.
// encoding using greedy leftmost-longest match with score-based tie-breaking,
// matching llama.cpp behavior.
//
// Stable.
type BPETokenizer struct {
Expand Down Expand Up @@ -313,109 +314,62 @@ func (t *BPETokenizer) SetScores(scores []float32) {
}
}

// sentencePieceEncode tokenizes text using Viterbi dynamic programming to find
// the segmentation that maximizes the sum of log-probability scores.
// sentencePieceEncode tokenizes text using greedy leftmost-longest match.
//
// This is used for SentencePiece unigram models that provide vocabulary
// scores but no BPE merge table (e.g., Mistral 7B GGUF). The Viterbi approach
// finds the globally optimal segmentation, unlike greedy longest-match which
// can produce suboptimal splits.
// At each position, the longest vocabulary token is selected. Ties in length
// are broken by score (higher wins). This matches the llama.cpp SentencePiece
// tokenizer (llm_tokenizer_spm::tokenize) and produces the same output as
// HuggingFace for GGUF models. When no vocabulary token matches a byte,
// <0xNN> byte fallback tokens are used.
func (t *BPETokenizer) sentencePieceEncode(text string) []int {
if text == "" {
return nil
}

n := len(text) // byte length
var ids []int
pos := 0
n := len(text)

// Viterbi forward pass: find best segmentation.
// bestScore[i] = best total score for text[:i]
// bestLen[i] = byte length of the last token in the best path ending at i
bestScore := make([]float64, n+1)
bestLen := make([]int, n+1)
for i := range bestScore {
bestScore[i] = math.Inf(-1)
}
bestScore[0] = 0
for pos < n {
// Find the longest matching token at the current position.
bestID := -1
bestLen := 0
bestScore := float32(math.Inf(-1))

for i := 0; i < n; i++ {
if math.IsInf(bestScore[i], -1) {
continue
}
// Try all possible tokens starting at position i.
maxLen := t.maxTokenLen
if maxLen > n-i {
maxLen = n - i
if maxLen > n-pos {
maxLen = n - pos
}

for tokenLen := 1; tokenLen <= maxLen; tokenLen++ {
candidate := text[i : i+tokenLen]
candidate := text[pos : pos+tokenLen]
if id, ok := t.vocab[candidate]; ok {
score := bestScore[i] + float64(t.tokenScore(id))
if score > bestScore[i+tokenLen] {
bestScore[i+tokenLen] = score
bestLen[i+tokenLen] = tokenLen
score := t.tokenScore(id)
// Prefer longest match; break ties by score.
if tokenLen > bestLen || (tokenLen == bestLen && score > bestScore) {
bestID = id
bestLen = tokenLen
bestScore = score
}
}
}
// Byte fallback: use <0xNN> as last resort when no vocab token covers
// position i. Byte tokens get a fixed penalty of -1e6 so they never
// beat real vocabulary tokens in the Viterbi DP. This matches
// llama.cpp / SentencePiece behavior where byte fallback is only
// used for characters that have no vocabulary coverage.
byteToken := fmt.Sprintf("<0x%02X>", text[i])
if id, ok := t.vocab[byteToken]; ok {
_ = id // byte token exists but we ignore its vocab score
score := bestScore[i] + (-1e6)
if score > bestScore[i+1] {
bestScore[i+1] = score
bestLen[i+1] = 1
}
} else {
// Byte token not in vocab; use unknown score as last resort.
score := bestScore[i] + float64(t.unknownScore())
if score > bestScore[i+1] {
bestScore[i+1] = score
bestLen[i+1] = 1
}
}
}

// If we can't reach the end, return nil.
if math.IsInf(bestScore[n], -1) {
return nil
}

// Backtrack to find token sequence.
var tokens []int
pos := n
for pos > 0 {
tokLen := bestLen[pos]
candidate := text[pos-tokLen : pos]
if id, ok := t.vocab[candidate]; ok {
tokens = append(tokens, id)
if bestID >= 0 {
ids = append(ids, bestID)
pos += bestLen
} else {
// Byte fallback for single-byte token.
byteToken := fmt.Sprintf("<0x%02X>", text[pos-tokLen])
// Byte fallback: use <0xNN> for the current byte.
byteToken := fmt.Sprintf("<0x%02X>", text[pos])
if id, ok := t.vocab[byteToken]; ok {
tokens = append(tokens, id)
ids = append(ids, id)
} else {
tokens = append(tokens, t.special.UNK)
ids = append(ids, t.special.UNK)
}
pos++
}
pos -= tokLen
}

// Reverse (we built it backwards).
for i, j := 0, len(tokens)-1; i < j; i, j = i+1, j-1 {
tokens[i], tokens[j] = tokens[j], tokens[i]
}

return tokens
}

// unknownScore returns a very negative score used for byte fallback tokens
// when the <0xNN> token is not in the vocabulary.
func (t *BPETokenizer) unknownScore() float32 {
return -100.0
return ids
}

// tokenScore returns the score for a token ID, or 0 if scores are not set
Expand Down
40 changes: 19 additions & 21 deletions bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,9 @@ func TestSentencePieceUnigram_RoundTrip(t *testing.T) {
func TestSentencePieceUnigram_UnknownChars(t *testing.T) {
tok := makeTestSentencePieceUnigram()

// Characters not in vocab should produce UNK or ▁ tokens via Viterbi.
// Characters not in vocab should produce UNK or ▁ tokens via byte fallback.
// "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab,
// Viterbi will match ▁ first, then x, y, z via byte fallback or UNK.
// greedy will match ▁ first, then x, y, z via byte fallback or UNK.
ids, err := tok.Encode("xyz")
if err != nil {
t.Fatalf("Encode error: %v", err)
Expand All @@ -576,11 +576,11 @@ func TestSentencePieceUnigram_UnknownChars(t *testing.T) {
}
}

func TestSentencePieceUnigram_ViterbiOptimal(t *testing.T) {
func TestSentencePieceUnigram_LongestMatch(t *testing.T) {
tok := makeTestSentencePieceUnigram()

// "Hello" should encode as one token ▁Hello (id=8) via Viterbi,
// since it has the best score (-1.0) vs splitting into subwords.
// "Hello" should encode as one token ▁Hello (id=8) via greedy longest-match,
// since it is the longest matching token at position 0.
ids, err := tok.Encode("Hello")
if err != nil {
t.Fatalf("Encode error: %v", err)
Expand Down Expand Up @@ -608,11 +608,10 @@ func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) {
}
}

func TestSentencePieceUnigram_ViterbiBeatsGreedy(t *testing.T) {
// This test demonstrates that Viterbi finds a better segmentation than greedy.
// Vocab has "▁hel" and "lo" with good scores, and "▁h" and "ello" with worse scores.
// Greedy longest-match would pick "▁hello" if available, or "▁hell" + "o".
// But here we set up scores so "▁hel" + "lo" is strictly better than "▁hell" + "o".
func TestSentencePieceUnigram_GreedyLongestMatch(t *testing.T) {
// Greedy leftmost-longest match picks the longest token at each position.
// For "▁hello": longest match at pos 0 is "▁hell" (4 chars), then "o".
// This matches llama.cpp / HuggingFace SentencePiece behavior.
vocab := map[string]int{
"<unk>": 0,
"<s>": 1,
Expand All @@ -633,11 +632,11 @@ func TestSentencePieceUnigram_ViterbiBeatsGreedy(t *testing.T) {
scores[3] = -5.0 // ▁
scores[4] = -4.0 // ▁h
scores[5] = -3.0 // ▁he
scores[6] = -1.5 // ▁hel (good)
scores[7] = -3.0 // ▁hell (worse than ▁hel)
scores[8] = -3.0 // o (bad)
scores[9] = -4.0 // l (bad)
scores[10] = -1.0 // lo (good)
scores[6] = -1.5 // ▁hel
scores[7] = -3.0 // ▁hell
scores[8] = -3.0 // o
scores[9] = -4.0 // l
scores[10] = -1.0 // lo

special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
tok := NewBPETokenizer(vocab, nil, special, false)
Expand All @@ -648,9 +647,8 @@ func TestSentencePieceUnigram_ViterbiBeatsGreedy(t *testing.T) {
if err != nil {
t.Fatalf("Encode error: %v", err)
}
// Viterbi should choose "▁hel" + "lo" (score -1.5 + -1.0 = -2.5)
// over "▁hell" + "o" (score -3.0 + -3.0 = -6.0).
want := []int{6, 10} // ▁hel, lo
// Greedy picks longest match: "▁hell" (id=7) then "o" (id=8).
want := []int{7, 8} // ▁hell, o
if len(ids) != len(want) {
t.Fatalf("Encode(\"hello\") = %v, want %v", ids, want)
}
Expand Down Expand Up @@ -940,8 +938,8 @@ func TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars(t *testing.T

func TestSentencePieceUnigram_AddLeadingSpaceDefault(t *testing.T) {
// Regression test: SetSentencePiece(true) must enable addLeadingSpace so
// the Viterbi receives "▁What" (7 bytes) as input rather than "What" (4 bytes).
// Without addLeadingSpace, the ▁ prefix is missing and the Viterbi produces
// the encoder receives "▁What" (7 bytes) as input rather than "What" (4 bytes).
// Without addLeadingSpace, the ▁ prefix is missing and the encoder produces
// byte-level or character-level fallback tokens instead of matching "▁What".
vocab := map[string]int{
"<unk>": 0,
Expand Down Expand Up @@ -997,7 +995,7 @@ func TestSentencePieceUnigram_AddLeadingSpaceDefault(t *testing.T) {
}
// With addLeadingSpace=true, pre-tokenizer produces:
// ["▁What", "▁is", "▁the", "▁capital", "▁of", "▁France?"]
// The Viterbi should match ▁What (ID 3) as a single token.
// The encoder should match ▁What (ID 3) as a single token.
// Without addLeadingSpace, "What" has no ▁ prefix and falls back to
// character tokens [W, h, a, t] — this was the bug.
want := []int{3, 4, 5, 6, 7, 8, 9}
Expand Down
Loading