From 3e4fb7b49248b9c79ec4792cf376c6d07b15a10d Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 26 Mar 2026 09:33:14 -0700 Subject: [PATCH] fix: implement Viterbi SentencePiece encoding (replaces greedy) The greedy longest-match approach in sentencePieceEncode produced suboptimal tokenization for SentencePiece unigram models (e.g., Mistral 7B). Replace it with Viterbi dynamic programming that finds the segmentation maximizing the sum of log-probability scores. Also adds: - Byte fallback encoding/decoding via <0xNN> tokens for chars not in vocab - decodeSentencePieceBytes for proper round-trip of byte fallback tokens - Tests: Viterbi vs greedy, byte fallback, sentence round-trip, edge cases --- bpe.go | 159 ++++++++++++++++++++++++++--------- bpe_test.go | 235 +++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 347 insertions(+), 47 deletions(-) diff --git a/bpe.go b/bpe.go index 9aa7a1f..b43e058 100644 --- a/bpe.go +++ b/bpe.go @@ -2,6 +2,7 @@ package ztoken import ( "fmt" + "math" "strings" "unicode/utf8" ) @@ -202,6 +203,8 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) { return decoded, nil } if t.sentencePiece { + // Decode <0xNN> byte tokens back to actual bytes. + result = decodeSentencePieceBytes(result) // Replace ▁ with space and trim leading space. result = strings.ReplaceAll(result, "\u2581", " ") result = strings.TrimPrefix(result, " ") @@ -210,6 +213,43 @@ func (t *BPETokenizer) Decode(ids []int) (string, error) { return result, nil } +// decodeSentencePieceBytes replaces <0xNN> hex byte tokens with the +// corresponding raw bytes. This reverses the byte fallback encoding +// used by SentencePiece for characters not in the vocabulary. +func decodeSentencePieceBytes(s string) string { + var sb strings.Builder + i := 0 + for i < len(s) { + // Look for <0xNN> pattern: exactly 6 characters. 
+ if i+6 <= len(s) && s[i] == '<' && s[i+1] == '0' && s[i+2] == 'x' && s[i+5] == '>' { + hi := unhex(s[i+3]) + lo := unhex(s[i+4]) + if hi >= 0 && lo >= 0 { + sb.WriteByte(byte(hi<<4 | lo)) + i += 6 + continue + } + } + sb.WriteByte(s[i]) + i++ + } + return sb.String() +} + +// unhex converts a hex digit character to its value, or -1 if invalid. +func unhex(c byte) int { + switch { + case c >= '0' && c <= '9': + return int(c - '0') + case c >= 'A' && c <= 'F': + return int(c-'A') + 10 + case c >= 'a' && c <= 'f': + return int(c-'a') + 10 + default: + return -1 + } +} + // VocabSize returns the number of tokens in the vocabulary. func (t *BPETokenizer) VocabSize() int { return len(t.vocab) @@ -259,59 +299,104 @@ func (t *BPETokenizer) SetScores(scores []float32) { } } -// sentencePieceEncode tokenizes text using greedy longest-match with -// score-based selection. For each position, it finds all vocabulary tokens -// that match the input at that position, selects the longest match (breaking -// ties by highest score), and advances past the matched token. +// sentencePieceEncode tokenizes text using Viterbi dynamic programming to find +// the segmentation that maximizes the sum of log-probability scores. // // This is used for SentencePiece unigram models that provide vocabulary -// scores but no BPE merge table (e.g., Mistral 7B GGUF). +// scores but no BPE merge table (e.g., Mistral 7B GGUF). The Viterbi approach +// finds the globally optimal segmentation, unlike greedy longest-match which +// can produce suboptimal splits. func (t *BPETokenizer) sentencePieceEncode(text string) []int { if text == "" { return nil } - var ids []int - pos := 0 - textBytes := []byte(text) - n := len(textBytes) - for pos < n { - bestLen := 0 - bestID := t.special.UNK - bestScore := float32(-1e30) - - // Search for the longest matching token at this position. - // Limit search window to maxTokenLen bytes. 
- maxEnd := pos + t.maxTokenLen - if maxEnd > n { - maxEnd = n - } + n := len(text) // byte length + + // Viterbi forward pass: find best segmentation. + // bestScore[i] = best total score for text[:i] + // bestLen[i] = byte length of the last token in the best path ending at i + bestScore := make([]float64, n+1) + bestLen := make([]int, n+1) + for i := range bestScore { + bestScore[i] = math.Inf(-1) + } + bestScore[0] = 0 - for end := pos + 1; end <= maxEnd; end++ { - candidate := string(textBytes[pos:end]) + for i := 0; i < n; i++ { + if math.IsInf(bestScore[i], -1) { + continue + } + // Try all possible tokens starting at position i. + maxLen := t.maxTokenLen + if maxLen > n-i { + maxLen = n - i + } + for tokenLen := 1; tokenLen <= maxLen; tokenLen++ { + candidate := text[i : i+tokenLen] if id, ok := t.vocab[candidate]; ok { - candidateLen := end - pos - // Prefer longer matches. For equal length, prefer higher score. - if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) { - bestLen = candidateLen - bestID = id - bestScore = t.tokenScore(id) + score := bestScore[i] + float64(t.tokenScore(id)) + if score > bestScore[i+tokenLen] { + bestScore[i+tokenLen] = score + bestLen[i+tokenLen] = tokenLen } } } + // Byte fallback: if no vocab token covers position i, use <0xNN>. + byteToken := fmt.Sprintf("<0x%02X>", text[i]) + if id, ok := t.vocab[byteToken]; ok { + score := bestScore[i] + float64(t.tokenScore(id)) + if score > bestScore[i+1] { + bestScore[i+1] = score + bestLen[i+1] = 1 + } + } else { + // Byte token not in vocab; use unknown score as last resort. + score := bestScore[i] + float64(t.unknownScore()) + if score > bestScore[i+1] { + bestScore[i+1] = score + bestLen[i+1] = 1 + } + } + } - if bestLen == 0 { - // No matching token found; emit UNK and advance by one byte. - ids = append(ids, t.special.UNK) - // Advance past one UTF-8 character, not just one byte. 
- _, size := decodeRune(textBytes[pos:]) - pos += size + // If we can't reach the end, return nil. + if math.IsInf(bestScore[n], -1) { + return nil + } + + // Backtrack to find token sequence. + var tokens []int + pos := n + for pos > 0 { + tokLen := bestLen[pos] + candidate := text[pos-tokLen : pos] + if id, ok := t.vocab[candidate]; ok { + tokens = append(tokens, id) } else { - ids = append(ids, bestID) - pos += bestLen + // Byte fallback for single-byte token. + byteToken := fmt.Sprintf("<0x%02X>", text[pos-tokLen]) + if id, ok := t.vocab[byteToken]; ok { + tokens = append(tokens, id) + } else { + tokens = append(tokens, t.special.UNK) + } } + pos -= tokLen + } + + // Reverse (we built it backwards). + for i, j := 0, len(tokens)-1; i < j; i, j = i+1, j-1 { + tokens[i], tokens[j] = tokens[j], tokens[i] } - return ids + + return tokens +} + +// unknownScore returns a very negative score used for byte fallback tokens +// when the <0xNN> token is not in the vocabulary. +func (t *BPETokenizer) unknownScore() float32 { + return -100.0 } // tokenScore returns the score for a token ID, or 0 if scores are not set diff --git a/bpe_test.go b/bpe_test.go index 38c6cdd..56fd881 100644 --- a/bpe_test.go +++ b/bpe_test.go @@ -1,6 +1,7 @@ package ztoken import ( + "fmt" "testing" ) @@ -558,27 +559,28 @@ func TestSentencePieceUnigram_RoundTrip(t *testing.T) { func TestSentencePieceUnigram_UnknownChars(t *testing.T) { tok := makeTestSentencePieceUnigram() - // Characters not in vocab should produce UNK tokens. + // Characters not in vocab should produce UNK or ▁ tokens via Viterbi. + // "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab, + // Viterbi will match ▁ first, then x, y, z via byte fallback or UNK. ids, err := tok.Encode("xyz") if err != nil { t.Fatalf("Encode error: %v", err) } - // "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab, - // the greedy matcher will match ▁ first, then x, y, z individually → all UNK. 
+	if len(ids) == 0 {
+		t.Fatal("expected non-empty token list for 'xyz'")
+	}
 	for _, id := range ids {
-		if id != tok.special.UNK {
-			// Either ▁ (id=3) or UNK (id=0) are acceptable since ▁ is in vocab.
-			if id != 3 {
-				t.Errorf("expected UNK or ▁ token for unknown chars, got id=%d", id)
-			}
+		if id != tok.special.UNK && id != 3 {
+			t.Errorf("expected UNK (0) or ▁ (3) token for unknown chars, got id=%d", id)
 		}
 	}
 }
 
-func TestSentencePieceUnigram_PrefersLongestMatch(t *testing.T) {
+func TestSentencePieceUnigram_ViterbiOptimal(t *testing.T) {
 	tok := makeTestSentencePieceUnigram()
-	// "Hello" should encode as one token ▁Hello (id=8), not ▁H + e + l + l + o.
+	// "Hello" should encode as one token ▁Hello (id=8) via Viterbi,
+	// since it has the best score (-1.0) vs splitting into subwords.
 	ids, err := tok.Encode("Hello")
 	if err != nil {
 		t.Fatalf("Encode error: %v", err)
@@ -605,3 +607,216 @@ func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) {
 		t.Errorf("with merges present, expected BPE encoding [17], got %v", ids)
 	}
 }
+
+func TestSentencePieceUnigram_ViterbiBeatsGreedy(t *testing.T) {
+	// This test demonstrates that Viterbi finds a better segmentation than greedy.
+	// Vocab has "▁hel" and "lo" with good scores, and "▁hell" and "o" with worse scores.
+	// Greedy longest-match would pick "▁hello" if available, or "▁hell" + "o".
+	// But here we set up scores so "▁hel" + "lo" is strictly better than "▁hell" + "o".
+	vocab := map[string]int{
+		"<unk>":        0,
+		"<s>":          1,
+		"</s>":         2,
+		"\u2581":       3,
+		"\u2581h":      4,
+		"\u2581he":     5,
+		"\u2581hel":    6,
+		"\u2581hell":   7,
+		"o":            8,
+		"l":            9,
+		"lo":           10,
+	}
+	scores := make([]float32, 11)
+	scores[0] = -100  // <unk>
+	scores[1] = -100  // <s>
+	scores[2] = -100  // </s>
+	scores[3] = -5.0  // ▁
+	scores[4] = -4.0  // ▁h
+	scores[5] = -3.0  // ▁he
+	scores[6] = -1.5  // ▁hel (good)
+	scores[7] = -3.0  // ▁hell (worse than ▁hel)
+	scores[8] = -3.0  // o (bad)
+	scores[9] = -4.0  // l (bad)
+	scores[10] = -1.0 // lo (good)
+
+	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
+	tok := NewBPETokenizer(vocab, nil, special, false)
+	tok.SetSentencePiece(true)
+	tok.SetScores(scores)
+
+	ids, err := tok.Encode("hello")
+	if err != nil {
+		t.Fatalf("Encode error: %v", err)
+	}
+	// Viterbi should choose "▁hel" + "lo" (score -1.5 + -1.0 = -2.5)
+	// over "▁hell" + "o" (score -3.0 + -3.0 = -6.0).
+	want := []int{6, 10} // ▁hel, lo
+	if len(ids) != len(want) {
+		t.Fatalf("Encode(\"hello\") = %v, want %v", ids, want)
+	}
+	for i, id := range ids {
+		if id != want[i] {
+			t.Errorf("Encode(\"hello\")[%d] = %d, want %d", i, id, want[i])
+		}
+	}
+}
+
+// makeTestSentencePieceUnigramWithBytes creates a SentencePiece unigram
+// tokenizer that includes <0xNN> byte fallback tokens.
+func makeTestSentencePieceUnigramWithBytes() *BPETokenizer {
+	vocab := map[string]int{
+		"<unk>":          0,
+		"<s>":            1,
+		"</s>":           2,
+		"\u2581":         3,
+		"\u2581the":      4,
+		"\u2581capital":  5,
+		"\u2581of":       6,
+		"\u2581France":   7,
+		"\u2581is":       8,
+		"\u2581Paris":    9,
+		"\u2581a":        10,
+		"a":              11,
+		"t":              12,
+		"h":              13,
+		"e":              14,
+	}
+	// Add byte fallback tokens for all 256 bytes.
+	nextID := 15
+	for b := 0; b < 256; b++ {
+		tok := fmt.Sprintf("<0x%02X>", b)
+		vocab[tok] = nextID
+		nextID++
+	}
+
+	scores := make([]float32, nextID)
+	scores[0] = -100  // <unk>
+	scores[1] = -100  // <s>
+	scores[2] = -100  // </s>
+	scores[3] = -5.0  // ▁
+	scores[4] = -1.0  // ▁the
+	scores[5] = -0.5  // ▁capital
+	scores[6] = -1.0  // ▁of
+	scores[7] = -0.5  // ▁France
+	scores[8] = -1.0  // ▁is
+	scores[9] = -0.5  // ▁Paris
+	scores[10] = -2.0 // ▁a
+	scores[11] = -4.0 // a
+	scores[12] = -4.0 // t
+	scores[13] = -4.0 // h
+	scores[14] = -4.0 // e
+	// Byte fallback tokens get very negative scores.
+	for i := 15; i < nextID; i++ {
+		scores[i] = -10.0
+	}
+
+	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
+	tok := NewBPETokenizer(vocab, nil, special, false)
+	tok.SetSentencePiece(true)
+	tok.SetScores(scores)
+	return tok
+}
+
+func TestSentencePieceUnigram_EncodeDecodeSentence(t *testing.T) {
+	tok := makeTestSentencePieceUnigramWithBytes()
+
+	text := "The capital of France is Paris"
+	ids, err := tok.Encode(text)
+	if err != nil {
+		t.Fatalf("Encode(%q) error: %v", text, err)
+	}
+	if len(ids) == 0 {
+		t.Fatalf("Encode(%q) produced empty result", text)
+	}
+	decoded, err := tok.Decode(ids)
+	if err != nil {
+		t.Fatalf("Decode error: %v", err)
+	}
+	if decoded != text {
+		t.Errorf("round-trip failed: %q -> %v -> %q", text, ids, decoded)
+	}
+}
+
+func TestSentencePieceUnigram_ByteFallback(t *testing.T) {
+	tok := makeTestSentencePieceUnigramWithBytes()
+
+	// Encode a string with characters not directly in vocab.
+	// The é character (0xC3 0xA9) will require byte fallback via <0xNN> tokens.
+ text := "the \xc3\xa9" // "the é" — é is 0xC3 0xA9 in UTF-8 + ids, err := tok.Encode(text) + if err != nil { + t.Fatalf("Encode(%q) error: %v", text, err) + } + if len(ids) == 0 { + t.Fatalf("Encode(%q) produced empty result", text) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != text { + t.Errorf("byte fallback round-trip: %q -> %v -> %q", text, ids, decoded) + } +} + +func TestSentencePieceUnigram_EmptyAndSingle(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // Empty string. + ids, err := tok.Encode("") + if err != nil { + t.Fatalf("Encode empty error: %v", err) + } + if len(ids) != 0 { + t.Errorf("Encode(\"\") = %v, want []", ids) + } + + // Single character that exists in vocab. + ids, err = tok.Encode("a") + if err != nil { + t.Fatalf("Encode(\"a\") error: %v", err) + } + if len(ids) == 0 { + t.Fatal("Encode(\"a\") produced empty result") + } +} + +func TestSentencePieceUnigram_LongText(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // Encode and decode a longer text with repeated words. 
+ text := "the test is a test the test is a test" + ids, err := tok.Encode(text) + if err != nil { + t.Fatalf("Encode error: %v", err) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != text { + t.Errorf("round-trip failed: %q -> %v -> %q", text, ids, decoded) + } +} + +func TestDecodeSentencePieceBytes(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"no byte tokens", "hello", "hello"}, + {"single byte", "<0x41>", "A"}, + {"multiple bytes", "<0xC3><0xA9>", "\xc3\xa9"}, // é + {"mixed", "hello<0x21>world", "hello!world"}, + {"invalid hex preserved", "<0xZZ>", "<0xZZ>"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := decodeSentencePieceBytes(tc.input) + if got != tc.want { + t.Errorf("decodeSentencePieceBytes(%q) = %q, want %q", tc.input, got, tc.want) + } + }) + } +}