Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ztoken
import (
"fmt"
"strings"
"unicode/utf8"
)

// MergePair represents an adjacent token pair used in BPE merging.
Expand All @@ -15,6 +16,8 @@ type MergePair struct {

// BPETokenizer implements the Tokenizer interface using byte-pair encoding.
// It loads vocabulary and merge rules from HuggingFace tokenizer.json format.
// When scores are set and merges are empty, it falls back to SentencePiece
// unigram encoding using greedy longest-match with score-based selection.
//
// Stable.
type BPETokenizer struct {
Expand All @@ -38,6 +41,13 @@ type BPETokenizer struct {
specialTokens map[string]int
// normalizer is an optional text normalization function applied before tokenization.
normalizer NormalizerFunc
// scores holds SentencePiece unigram scores (negative log probabilities)
// indexed by token ID. When scores are set and merges are empty, the
// tokenizer uses greedy longest-match encoding instead of BPE merging.
scores []float32
// maxTokenLen caches the length (in bytes) of the longest token in vocab,
// used to bound the search window in sentencePieceEncode.
maxTokenLen int
}

// NewBPETokenizer creates a BPETokenizer from vocabulary, merge rules, and special tokens.
Expand Down Expand Up @@ -94,13 +104,21 @@ func (t *BPETokenizer) encodeSegment(text string, addLeadingSpace bool) ([]int,
} else {
words = strings.Fields(text)
}
// When merges are empty but scores are available, use SentencePiece
// unigram encoding (greedy longest-match) instead of BPE merging.
useUnigram := len(t.mergeRanks) == 0 && len(t.scores) > 0

var ids []int
for _, word := range words {
wordIDs, err := t.encodeWord(word)
if err != nil {
return nil, err
if useUnigram {
ids = append(ids, t.sentencePieceEncode(word)...)
} else {
wordIDs, err := t.encodeWord(word)
if err != nil {
return nil, err
}
ids = append(ids, wordIDs...)
}
ids = append(ids, wordIDs...)
}
return ids, nil
}
Expand Down Expand Up @@ -226,6 +244,85 @@ func (t *BPETokenizer) SetSpecialTokenStrings(tokens map[string]int) {
t.specialTokens = tokens
}

// SetScores sets token scores for SentencePiece unigram encoding.
// When scores are set and merges are empty, the tokenizer uses
// score-based greedy encoding instead of BPE merge-based encoding.
// Scores are indexed by token ID (negative log probabilities).
func (t *BPETokenizer) SetScores(scores []float32) {
	t.scores = scores
	// Cache the byte length of the longest vocab entry so the greedy
	// matcher in sentencePieceEncode can bound its search window.
	longest := 0
	for tok := range t.vocab {
		if n := len(tok); n > longest {
			longest = n
		}
	}
	t.maxTokenLen = longest
}

// sentencePieceEncode tokenizes text using greedy longest-match with
// score-based selection. For each position, it finds all vocabulary tokens
// that match the input at that position, selects the longest match (breaking
// ties by highest score), and advances past the matched token.
//
// This is used for SentencePiece unigram models that provide vocabulary
// scores but no BPE merge table (e.g., Mistral 7B GGUF).
func (t *BPETokenizer) sentencePieceEncode(text string) []int {
	if text == "" {
		return nil
	}
	var ids []int
	pos := 0
	n := len(text)

	for pos < n {
		bestLen := 0
		bestID := t.special.UNK
		bestScore := float32(-1e30)

		// Search for the longest matching token at this position.
		// Limit search window to maxTokenLen bytes.
		maxEnd := pos + t.maxTokenLen
		if maxEnd > n {
			maxEnd = n
		}

		// Slice the input string directly: substrings share the backing
		// storage, so each candidate lookup is allocation-free (unlike
		// converting to []byte and building a string per candidate).
		for end := pos + 1; end <= maxEnd; end++ {
			if id, ok := t.vocab[text[pos:end]]; ok {
				candidateLen := end - pos
				// Prefer longer matches. For equal length, prefer higher score.
				if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) {
					bestLen = candidateLen
					bestID = id
					bestScore = t.tokenScore(id)
				}
			}
		}

		if bestLen == 0 {
			// No matching token found; emit UNK and advance past one full
			// UTF-8 character, not just one byte. DecodeRuneInString returns
			// size 1 for invalid bytes, so forward progress is guaranteed
			// (text[pos:] is never empty here because pos < n).
			ids = append(ids, t.special.UNK)
			_, size := utf8.DecodeRuneInString(text[pos:])
			pos += size
		} else {
			ids = append(ids, bestID)
			pos += bestLen
		}
	}
	return ids
}

// tokenScore returns the score for a token ID, or 0 if scores are not set
// or the ID is out of range.
func (t *BPETokenizer) tokenScore(id int) float32 {
	if id < 0 || id >= len(t.scores) {
		return 0
	}
	return t.scores[id]
}

// sentencePiecePreTokenize implements SentencePiece-style pre-tokenization.
// Text is split on whitespace boundaries. Words that follow a space get ▁
// (U+2581) prepended. Newlines are emitted as separate tokens.
Expand Down Expand Up @@ -404,5 +501,15 @@ func isPrintableGPT2Byte(b byte) bool {
return false
}

// decodeRune decodes the first UTF-8 rune from b and returns it with its byte length.
// If b is empty or invalid, it returns utf8.RuneError and 1 to ensure forward progress.
func decodeRune(b []byte) (rune, int) {
	r, size := utf8.DecodeRune(b)
	if size > 0 {
		return r, size
	}
	// utf8.DecodeRune reports size 0 only for empty input; report a
	// width of 1 so callers always advance.
	return utf8.RuneError, 1
}

// Statically assert BPETokenizer implements Tokenizer.
var _ Tokenizer = (*BPETokenizer)(nil)
207 changes: 207 additions & 0 deletions bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,210 @@ func TestBPETokenizer_ByteLevelBPE(t *testing.T) {
t.Errorf("Decode(%v) = %q, want \"hi\"", ids, decoded)
}
}

// makeTestSentencePieceUnigram creates a SentencePiece unigram tokenizer
// with vocabulary and scores but no merges, simulating Mistral 7B GGUF.
func makeTestSentencePieceUnigram() *BPETokenizer {
	vocab := map[string]int{
		"<unk>": 0,
		"<s>":   1,
		"</s>":  2,
		"\u2581": 3, // ▁
		"\u2581H":     4,
		"\u2581He":    5,
		"\u2581Hel":   6,
		"\u2581Hell":  7,
		"\u2581Hello": 8,
		"\u2581w":     9,
		"\u2581wo":    10,
		"\u2581wor":   11,
		"\u2581worl":  12,
		"\u2581world": 13,
		"H": 14,
		"e": 15,
		"l": 16,
		"o": 17,
		"w": 18,
		"r": 19,
		"d": 20,
		"\u2581the":  21,
		"\u2581is":   22,
		"\u2581a":    23,
		"\u2581test": 24,
		"t": 25,
		"s": 26,
	}

	// Scores: higher (less negative) = more likely. Longer tokens get
	// better scores. IDs not listed here are single characters at -4.0.
	scoreByID := map[int]float32{
		0: -100, 1: -100, 2: -100, // <unk>, <s>, </s>
		3: -5.0,                                      // ▁
		4: -3.0, 5: -2.5, 6: -2.0, 7: -1.5, 8: -1.0, // ▁H … ▁Hello (best for "Hello")
		9: -3.0, 10: -2.5, 11: -2.0, 12: -1.5, 13: -1.0, // ▁w … ▁world (best for "world")
		21: -1.0, 22: -1.0, 23: -1.5, 24: -1.0, // ▁the ▁is ▁a ▁test
	}
	scores := make([]float32, 27)
	for id := range scores {
		if s, ok := scoreByID[id]; ok {
			scores[id] = s
		} else {
			scores[id] = -4.0 // single-character tokens
		}
	}

	special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0}
	// No merges — this is a unigram model.
	tok := NewBPETokenizer(vocab, nil, special, false)
	tok.SetSentencePiece(true)
	tok.SetScores(scores)
	return tok
}

func TestSentencePieceUnigram_Encode(t *testing.T) {
	tok := makeTestSentencePieceUnigram()

	cases := []struct {
		name    string
		input   string
		wantIDs []int
	}{
		{"single word", "Hello", []int{8}},                             // ▁Hello
		{"two words", "Hello world", []int{8, 13}},                     // ▁Hello ▁world
		{"sentence", "the world is a test", []int{21, 13, 22, 23, 24}}, // ▁the ▁world ▁is ▁a ▁test
		{"empty string", "", []int{}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tok.Encode(tc.input)
			if err != nil {
				t.Fatalf("Encode(%q) error: %v", tc.input, err)
			}
			if len(got) != len(tc.wantIDs) {
				t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", tc.input, got, len(got), tc.wantIDs, len(tc.wantIDs))
			}
			for i := range got {
				if got[i] != tc.wantIDs[i] {
					t.Errorf("Encode(%q)[%d] = %d, want %d", tc.input, i, got[i], tc.wantIDs[i])
				}
			}
		})
	}
}

func TestSentencePieceUnigram_Decode(t *testing.T) {
	tok := makeTestSentencePieceUnigram()

	cases := []struct {
		name     string
		ids      []int
		wantText string
		wantErr  bool
	}{
		{"single token", []int{8}, "Hello", false},
		{"multiple tokens", []int{8, 13}, "Hello world", false},
		{"empty", []int{}, "", false},
		{"unknown ID", []int{999}, "", true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tok.Decode(tc.ids)
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("Decode(%v) expected error, got %q", tc.ids, got)
			case tc.wantErr:
				// Got the expected error; nothing more to check.
			case err != nil:
				t.Fatalf("Decode(%v) error: %v", tc.ids, err)
			case got != tc.wantText:
				t.Errorf("Decode(%v) = %q, want %q", tc.ids, got, tc.wantText)
			}
		})
	}
}

func TestSentencePieceUnigram_RoundTrip(t *testing.T) {
	tok := makeTestSentencePieceUnigram()

	// Encode then decode each input and require the original text back.
	for _, input := range []string{"Hello", "Hello world", "the world is a test"} {
		ids, err := tok.Encode(input)
		if err != nil {
			t.Fatalf("Encode(%q) error: %v", input, err)
		}
		decoded, err := tok.Decode(ids)
		if err != nil {
			t.Fatalf("Decode(%v) error: %v", ids, err)
		}
		if decoded != input {
			t.Errorf("round-trip failed: %q -> %v -> %q", input, ids, decoded)
		}
	}
}

func TestSentencePieceUnigram_UnknownChars(t *testing.T) {
	tok := makeTestSentencePieceUnigram()

	// Characters not in vocab should produce UNK tokens.
	ids, err := tok.Encode("xyz")
	if err != nil {
		t.Fatalf("Encode error: %v", err)
	}
	// "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab,
	// the greedy matcher will match ▁ first, then x, y, z individually → all UNK.
	// Either ▁ (id=3) or UNK (id=0) are acceptable since ▁ is in vocab.
	for _, id := range ids {
		if id != tok.special.UNK && id != 3 {
			t.Errorf("expected UNK or ▁ token for unknown chars, got id=%d", id)
		}
	}
}

func TestSentencePieceUnigram_PrefersLongestMatch(t *testing.T) {
	tok := makeTestSentencePieceUnigram()

	// "Hello" should encode as one token ▁Hello (id=8), not ▁H + e + l + l + o.
	ids, err := tok.Encode("Hello")
	if err != nil {
		t.Fatalf("Encode error: %v", err)
	}
	// Fatalf, not Errorf: if ids were empty, the ids[0] access below
	// would panic instead of reporting a clean failure.
	if len(ids) != 1 {
		t.Fatalf("expected 1 token for 'Hello', got %d: %v", len(ids), ids)
	}
	if ids[0] != 8 {
		t.Errorf("expected token id 8 (▁Hello), got %d", ids[0])
	}
}

func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) {
	// When merges ARE present, unigram encoding should NOT be used
	// even if scores are also set.
	tok := makeTestBPE()
	tok.SetScores([]float32{0, 0, 0, 0}) // set scores but merges exist
	got, err := tok.Encode("hello")
	if err != nil {
		t.Fatalf("Encode error: %v", err)
	}
	// Should still use BPE merging, producing "hello" (id=17).
	if len(got) != 1 || got[0] != 17 {
		t.Errorf("with merges present, expected BPE encoding [17], got %v", got)
	}
}
8 changes: 8 additions & 0 deletions gguf/gguf.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Metadata interface {
GetStringArray(key string) ([]string, bool)
GetUint32(key string) (uint32, bool)
GetInt32Array(key string) ([]int32, bool)
GetFloat32Array(key string) ([]float32, bool)
}

// ExtractTokenizer builds a BPETokenizer from GGUF metadata. GGUF files store
Expand Down Expand Up @@ -72,6 +73,13 @@ func ExtractTokenizer(m Metadata) (*ztoken.BPETokenizer, error) {
tok.SetSentencePiece(true)
}

// Extract token scores for SentencePiece unigram models. When scores
// are present but merges are absent, the tokenizer uses greedy
// longest-match encoding instead of BPE merge-based encoding.
if scores, ok := m.GetFloat32Array("tokenizer.ggml.scores"); ok {
tok.SetScores(scores)
}

// Extract control/special tokens (token_type == 3) for exact matching
// during encoding. Without this, tokens like <start_of_turn> would be
// split into characters by BPE.
Expand Down
Loading
Loading