From f6da4ab1fc0c49026ae6aacb6fe5ed9ec1281892 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 26 Mar 2026 11:13:56 -0700 Subject: [PATCH] fix: use large penalty for byte fallback in SentencePiece Viterbi Byte fallback tokens (<0xNN>) were competing with multi-character vocab tokens in the Viterbi DP using their actual vocabulary scores. When byte token scores happened to be higher than vocab token scores, the Viterbi algorithm preferred 43 byte-level tokens over 7 word-level tokens. Fix: assign byte fallback tokens a fixed score of -1e6 instead of their vocabulary score, ensuring they are only used as a last resort when no vocab token covers a position. This matches llama.cpp behavior. --- bpe.go | 9 +++- bpe_test.go | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/bpe.go b/bpe.go index b43e058..e89f4ba 100644 --- a/bpe.go +++ b/bpe.go @@ -342,10 +342,15 @@ func (t *BPETokenizer) sentencePieceEncode(text string) []int { } } } - // Byte fallback: if no vocab token covers position i, use <0xNN>. + // Byte fallback: use <0xNN> as last resort when no vocab token covers + // position i. Byte tokens get a fixed penalty of -1e6 so they never + // beat real vocabulary tokens in the Viterbi DP. This matches + // llama.cpp / SentencePiece behavior where byte fallback is only + // used for characters that have no vocabulary coverage. 
byteToken := fmt.Sprintf("<0x%02X>", text[i]) if id, ok := t.vocab[byteToken]; ok { - score := bestScore[i] + float64(t.tokenScore(id)) + _ = id // byte token exists but we ignore its vocab score + score := bestScore[i] + (-1e6) if score > bestScore[i+1] { bestScore[i+1] = score bestLen[i+1] = 1 diff --git a/bpe_test.go b/bpe_test.go index 56fd881..aa61f0e 100644 --- a/bpe_test.go +++ b/bpe_test.go @@ -799,6 +799,145 @@ func TestSentencePieceUnigram_LongText(t *testing.T) { } } +func TestSentencePieceUnigram_ByteFallbackNeverBeatsVocab(t *testing.T) { + // Regression test: byte fallback tokens must never be preferred over + // multi-character vocab tokens, even when byte token scores are higher. + // This was the original bug — byte tokens like <0xE2> had scores of 0.0 + // which beat multi-character tokens with negative scores, producing 43 + // byte-level tokens instead of 7 word tokens. + vocab := map[string]int{ + "<unk>": 0, + "<s>": 1, + "</s>": 2, + "\u2581": 3, + "\u2581What": 4, + "\u2581is": 5, + "\u2581the": 6, + "\u2581capital": 7, + "\u2581of": 8, + "\u2581France": 9, + "?": 10, + } + // Add byte fallback tokens for all 256 bytes. + nextID := 11 + for b := 0; b < 256; b++ { + tok := fmt.Sprintf("<0x%02X>", b) + vocab[tok] = nextID + nextID++ + } + + scores := make([]float32, nextID) + scores[0] = -100 // <unk> + scores[1] = -100 // <s> + scores[2] = -100 // </s> + scores[3] = -5.0 // ▁ + scores[4] = -8.0 // ▁What + scores[5] = -7.0 // ▁is + scores[6] = -6.0 // ▁the + scores[7] = -9.0 // ▁capital + scores[8] = -6.0 // ▁of + scores[9] = -9.0 // ▁France + scores[10] = -4.0 // ? + // Byte fallback tokens get HIGH scores (the bug scenario). + // Before the fix, these would win over multi-character vocab tokens.
+ for i := 11; i < nextID; i++ { + scores[i] = 0.0 + } + + special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0} + tok := NewBPETokenizer(vocab, nil, special, false) + tok.SetSentencePiece(true) + tok.SetScores(scores) + + ids, err := tok.Encode("What is the capital of France?") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // Must produce word-level tokens, not byte-level tokens. + // "What is the capital of France?" -> [▁What, ▁is, ▁the, ▁capital, ▁of, ▁France, ?] + want := []int{4, 5, 6, 7, 8, 9, 10} + if len(ids) != len(want) { + t.Fatalf("Encode produced %d tokens %v, want %d tokens %v", len(ids), ids, len(want), want) + } + for i, id := range ids { + if id != want[i] { + t.Errorf("[%d] = %d, want %d", i, id, want[i]) + } + } + + // Verify round-trip. + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != "What is the capital of France?" { + t.Errorf("Decode = %q, want %q", decoded, "What is the capital of France?") + } +} + +func TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars(t *testing.T) { + // Byte fallback must still be used for characters that have no + // vocabulary coverage (e.g., emoji, rare Unicode). + vocab := map[string]int{ + "<unk>": 0, + "<s>": 1, + "</s>": 2, + "\u2581": 3, + "\u2581hi": 4, + } + nextID := 5 + for b := 0; b < 256; b++ { + tok := fmt.Sprintf("<0x%02X>", b) + vocab[tok] = nextID + nextID++ + } + + scores := make([]float32, nextID) + scores[0] = -100 + scores[1] = -100 + scores[2] = -100 + scores[3] = -5.0 + scores[4] = -1.0 // ▁hi + for i := 5; i < nextID; i++ { + scores[i] = -2.0 // byte scores + } + + special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0} + tok := NewBPETokenizer(vocab, nil, special, false) + tok.SetSentencePiece(true) + tok.SetScores(scores) + + // "hi" has a vocab token; should use it.
+ ids, err := tok.Encode("hi") + if err != nil { + t.Fatalf("Encode(\"hi\") error: %v", err) + } + if len(ids) != 1 || ids[0] != 4 { + t.Errorf("Encode(\"hi\") = %v, want [4] (▁hi)", ids) + } + + // "hi\xc3\xa9" — é (U+00E9) is not in vocab, must use byte fallback. + ids, err = tok.Encode("hi\xc3\xa9") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // Should be: ▁hi + <0xC3> + <0xA9> + if len(ids) != 3 { + t.Fatalf("Encode(\"hi\\xc3\\xa9\") = %v (len=%d), want 3 tokens", ids, len(ids)) + } + if ids[0] != 4 { + t.Errorf("[0] = %d, want 4 (▁hi)", ids[0]) + } + // Verify round-trip through decode. + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != "hi\xc3\xa9" { + t.Errorf("Decode = %q, want %q", decoded, "hi\xc3\xa9") + } +} + func TestDecodeSentencePieceBytes(t *testing.T) { tests := []struct { name string