Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,15 @@ func (t *BPETokenizer) sentencePieceEncode(text string) []int {
}
}
}
// Byte fallback: if no vocab token covers position i, use <0xNN>.
// Byte fallback: use <0xNN> as last resort when no vocab token covers
// position i. Byte tokens get a fixed penalty of -1e6 so they never
// beat real vocabulary tokens in the Viterbi DP. This matches
// llama.cpp / SentencePiece behavior where byte fallback is only
// used for characters that have no vocabulary coverage.
byteToken := fmt.Sprintf("<0x%02X>", text[i])
if id, ok := t.vocab[byteToken]; ok {
score := bestScore[i] + float64(t.tokenScore(id))
_ = id // byte token exists but we ignore its vocab score
score := bestScore[i] + (-1e6)
if score > bestScore[i+1] {
bestScore[i+1] = score
bestLen[i+1] = 1
Expand Down
139 changes: 139 additions & 0 deletions bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,145 @@ func TestSentencePieceUnigram_LongText(t *testing.T) {
}
}

func TestSentencePieceUnigram_ByteFallbackNeverBeatsVocab(t *testing.T) {
	// Regression test: byte-fallback tokens must never be chosen over
	// multi-character vocabulary tokens, even when the byte tokens carry
	// higher scores. The original bug: byte tokens such as <0xE2> scored
	// 0.0 and outranked negatively-scored word tokens, so the encoder
	// emitted 43 byte-level tokens instead of 7 word-level ones.
	vocab := map[string]int{
		"<unk>":         0,
		"<s>":           1,
		"</s>":          2,
		"\u2581":        3,
		"\u2581What":    4,
		"\u2581is":      5,
		"\u2581the":     6,
		"\u2581capital": 7,
		"\u2581of":      8,
		"\u2581France":  9,
		"?":             10,
	}
	// Register byte-fallback entries <0x00>..<0xFF> after the word vocab.
	id := len(vocab)
	for b := 0; b < 256; b++ {
		vocab[fmt.Sprintf("<0x%02X>", b)] = id
		id++
	}

	scores := make([]float32, id)
	// Scores for ids 0..10, in order:
	// <unk>, <s>, </s>, ▁, ▁What, ▁is, ▁the, ▁capital, ▁of, ▁France, ?
	copy(scores, []float32{-100, -100, -100, -5.0, -8.0, -7.0, -6.0, -9.0, -6.0, -9.0, -4.0})
	// Byte-fallback tokens get HIGH scores (the bug scenario): before the
	// fix these would have beaten the multi-character vocab tokens.
	for i := 11; i < id; i++ {
		scores[i] = 0.0
	}

	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
	tokenizer := NewBPETokenizer(vocab, nil, special, false)
	tokenizer.SetSentencePiece(true)
	tokenizer.SetScores(scores)

	const input = "What is the capital of France?"
	ids, err := tokenizer.Encode(input)
	if err != nil {
		t.Fatalf("Encode error: %v", err)
	}
	// Must produce word-level tokens, not byte-level tokens:
	// "What is the capital of France?" -> [▁What, ▁is, ▁the, ▁capital, ▁of, ▁France, ?]
	want := []int{4, 5, 6, 7, 8, 9, 10}
	if len(ids) != len(want) {
		t.Fatalf("Encode produced %d tokens %v, want %d tokens %v", len(ids), ids, len(want), want)
	}
	for i := range want {
		if ids[i] != want[i] {
			t.Errorf("[%d] = %d, want %d", i, ids[i], want[i])
		}
	}

	// Round-trip: decoding must reproduce the original input exactly.
	decoded, err := tokenizer.Decode(ids)
	if err != nil {
		t.Fatalf("Decode error: %v", err)
	}
	if decoded != input {
		t.Errorf("Decode = %q, want %q", decoded, input)
	}
}

func TestSentencePieceUnigram_ByteFallbackStillWorksForUnknownChars(t *testing.T) {
	// Byte fallback must remain available for characters with no
	// vocabulary coverage (e.g., emoji, rare Unicode).
	vocab := map[string]int{
		"<unk>":    0,
		"<s>":      1,
		"</s>":     2,
		"\u2581":   3,
		"\u2581hi": 4,
	}
	// Register byte-fallback entries <0x00>..<0xFF> after the word vocab.
	id := len(vocab)
	for b := 0; b < 256; b++ {
		vocab[fmt.Sprintf("<0x%02X>", b)] = id
		id++
	}

	scores := make([]float32, id)
	// Scores for ids 0..4, in order: <unk>, <s>, </s>, ▁, ▁hi.
	copy(scores, []float32{-100, -100, -100, -5.0, -1.0})
	for i := 5; i < id; i++ {
		scores[i] = -2.0 // byte scores
	}

	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
	tokenizer := NewBPETokenizer(vocab, nil, special, false)
	tokenizer.SetSentencePiece(true)
	tokenizer.SetScores(scores)

	// "hi" is covered by the ▁hi vocab token; the encoder must use it.
	ids, err := tokenizer.Encode("hi")
	if err != nil {
		t.Fatalf("Encode(\"hi\") error: %v", err)
	}
	if len(ids) != 1 || ids[0] != 4 {
		t.Errorf("Encode(\"hi\") = %v, want [4] (▁hi)", ids)
	}

	// "hi\xc3\xa9" — é (U+00E9) is absent from the vocab, so its two
	// UTF-8 bytes must surface as byte-fallback tokens.
	ids, err = tokenizer.Encode("hi\xc3\xa9")
	if err != nil {
		t.Fatalf("Encode error: %v", err)
	}
	// Expected shape: ▁hi + <0xC3> + <0xA9>.
	if len(ids) != 3 {
		t.Fatalf("Encode(\"hi\\xc3\\xa9\") = %v (len=%d), want 3 tokens", ids, len(ids))
	}
	if ids[0] != 4 {
		t.Errorf("[0] = %d, want 4 (▁hi)", ids[0])
	}
	// Round-trip through decode must reproduce the raw input bytes.
	decoded, err := tokenizer.Decode(ids)
	if err != nil {
		t.Fatalf("Decode error: %v", err)
	}
	if decoded != "hi\xc3\xa9" {
		t.Errorf("Decode = %q, want %q", decoded, "hi\xc3\xa9")
	}
}

func TestDecodeSentencePieceBytes(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading