From 4620fdd2d45e5e80c10a07f24d2129db7f0a9c0e Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Fri, 20 Mar 2026 20:06:45 -0700 Subject: [PATCH 1/2] feat(wordpiece): add WordPiece tokenizer for BERT-family models - WordPieceTokenizer implementing Tokenizer interface with greedy longest-prefix subword splitting and ## continuation tokens - EncodeForBERT method producing input_ids, attention_mask, and token_type_ids for single sentences and sentence pairs with padding - Pre-tokenization splitting on whitespace and punctuation boundaries - Load function in loader.go dispatching to BPE or WordPiece based on model.type in tokenizer.json - extractSpecialTokens recognizes BERT-style [CLS]/[SEP]/[PAD]/[UNK] - Comprehensive tests: encode, decode, round-trip, BERT format, padding, sentence pairs, pre-tokenization, loader integration --- loader.go | 66 ++++++++- wordpiece.go | 272 ++++++++++++++++++++++++++++++++++ wordpiece_test.go | 362 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 693 insertions(+), 7 deletions(-) create mode 100644 wordpiece.go create mode 100644 wordpiece_test.go diff --git a/loader.go b/loader.go index b6e05f8..8e9939e 100644 --- a/loader.go +++ b/loader.go @@ -20,9 +20,10 @@ type tokenizerJSON struct { } type modelJSON struct { - Type string `json:"type"` - Vocab map[string]int `json:"vocab"` - RawMerges json.RawMessage `json:"merges"` + Type string `json:"type"` + Vocab map[string]int `json:"vocab"` + RawMerges json.RawMessage `json:"merges"` + ContinuingSubwordPrefix string `json:"continuing_subword_prefix"` } type addedTokenJSON struct { @@ -52,7 +53,31 @@ type decoderPatternJSON struct { String string `json:"String"` } +// Load reads a HuggingFace tokenizer.json file and returns the appropriate +// Tokenizer implementation based on the model type (BPE or WordPiece). +func Load(path string) (Tokenizer, error) { + data, err := os.ReadFile(path) //nolint:gosec // user-provided path + if err != nil { + return nil, fmt.Errorf("read tokenizer.json: %w", err) + } + + var tj tokenizerJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, fmt.Errorf("parse tokenizer.json: %w", err) + } + + switch tj.Model.Type { + case "WordPiece": + return loadWordPiece(tj) + case "", "BPE": + return loadBPE(tj) + default: + return nil, fmt.Errorf("unsupported model type: %q (supported: BPE, WordPiece)", tj.Model.Type) + } +} + // LoadFromJSON reads a HuggingFace tokenizer.json file and returns a BPETokenizer. +// For loading any tokenizer type, use [Load] instead. func LoadFromJSON(path string) (*BPETokenizer, error) { data, err := os.ReadFile(path) //nolint:gosec // user-provided path if err != nil { @@ -68,6 +93,11 @@ func LoadFromJSON(path string) (*BPETokenizer, error) { return nil, fmt.Errorf("unsupported model type: %q (only BPE supported)", tj.Model.Type) } + return loadBPE(tj) +} + +// loadBPE constructs a BPETokenizer from parsed JSON. +func loadBPE(tj tokenizerJSON) (*BPETokenizer, error) { // Parse merges — supports both ["a b", …] and [["a","b"], …] formats. merges, err := parseMerges(tj.Model.RawMerges) if err != nil { @@ -94,6 +124,26 @@ func LoadFromJSON(path string) (*BPETokenizer, error) { return tok, nil } +// loadWordPiece constructs a WordPieceTokenizer from parsed JSON. +func loadWordPiece(tj tokenizerJSON) (*WordPieceTokenizer, error) { + special := extractSpecialTokens(tj.AddedTokens) + normalizer := buildNormalizer(tj.Normalizer) + + tok := NewWordPieceTokenizer(tj.Model.Vocab, special) + tok.normalizer = normalizer + + // Register special token strings for exact matching. + specialMap := make(map[string]int) + for _, at := range tj.AddedTokens { + if at.Special { + specialMap[at.Content] = at.ID + } + } + tok.specialTokens = specialMap + + return tok, nil +} + // isByteLevelPreTokenizer returns true if the pre-tokenizer config uses ByteLevel. func isByteLevelPreTokenizer(pt *preTokenizerJSON) bool { if pt == nil { @@ -136,6 +186,8 @@ func isSentencePieceDecoder(d *decoderJSON) bool { } // extractSpecialTokens finds BOS, EOS, PAD, UNK from added_tokens. +// Recognizes both GPT-style (, , , ) and BERT-style +// ([CLS], [SEP], [PAD], [UNK]) special token conventions. func extractSpecialTokens(tokens []addedTokenJSON) SpecialTokens { special := SpecialTokens{} for _, t := range tokens { @@ -143,13 +195,13 @@ func extractSpecialTokens(tokens []addedTokenJSON) SpecialTokens { continue } switch { - case strings.Contains(t.Content, "bos") || t.Content == "": + case strings.Contains(t.Content, "bos") || t.Content == "" || t.Content == "[CLS]": special.BOS = t.ID - case strings.Contains(t.Content, "eos") || t.Content == "": + case strings.Contains(t.Content, "eos") || t.Content == "" || t.Content == "[SEP]": special.EOS = t.ID - case strings.Contains(t.Content, "pad") || t.Content == "": + case strings.Contains(t.Content, "pad") || t.Content == "" || t.Content == "[PAD]": special.PAD = t.ID - case strings.Contains(t.Content, "unk") || t.Content == "": + case strings.Contains(t.Content, "unk") || t.Content == "" || t.Content == "[UNK]": special.UNK = t.ID } } diff --git a/wordpiece.go b/wordpiece.go new file mode 100644 index 0000000..52d686f --- /dev/null +++ b/wordpiece.go @@ -0,0 +1,272 @@ +package ztoken + +import ( + "fmt" + "strings" + "unicode" +) + +// WordPieceTokenizer implements the Tokenizer interface using the WordPiece +// algorithm, as used by BERT-family models. It greedily matches the longest +// subword prefix from the vocabulary, using "##" to denote continuation tokens. +// +// Stable. +type WordPieceTokenizer struct { + vocab map[string]int + reverseVocab map[int]string + special SpecialTokens + normalizer NormalizerFunc + // maxTokenLen is the length of the longest token in the vocabulary, + // used to bound the greedy prefix search. + maxTokenLen int + // unkToken is the string representation of the unknown token. + unkToken string + // specialTokens maps special token strings to IDs for exact matching. + specialTokens map[string]int +} + +// BERTEncoding holds the input tensors expected by BERT-family models. +type BERTEncoding struct { + InputIDs []int // Token IDs: [CLS] + tokens + [SEP] (+ tokens + [SEP] for pairs) + AttentionMask []int // 1 for real tokens, 0 for padding + TokenTypeIDs []int // 0 for first sentence, 1 for second sentence +} + +// NewWordPieceTokenizer creates a WordPieceTokenizer from a vocabulary and special tokens. +func NewWordPieceTokenizer(vocab map[string]int, special SpecialTokens) *WordPieceTokenizer { + reverseVocab := make(map[int]string, len(vocab)) + maxLen := 0 + for k, v := range vocab { + reverseVocab[v] = k + if len(k) > maxLen { + maxLen = len(k) + } + } + + unkToken := "[UNK]" + if tok, ok := reverseVocab[special.UNK]; ok { + unkToken = tok + } + + return &WordPieceTokenizer{ + vocab: vocab, + reverseVocab: reverseVocab, + special: special, + maxTokenLen: maxLen, + unkToken: unkToken, + } +} + +// Encode tokenizes text into a sequence of token IDs using WordPiece. +func (t *WordPieceTokenizer) Encode(text string) ([]int, error) { + if text == "" { + return []int{}, nil + } + if t.normalizer != nil { + text = t.normalizer(text) + } + words := preTokenize(text) + var ids []int + for _, word := range words { + wordIDs := t.tokenizeWord(word) + ids = append(ids, wordIDs...) + } + return ids, nil +} + +// Decode converts token IDs back to text. Continuation tokens (##prefix) are +// joined without spaces to reconstruct words. +func (t *WordPieceTokenizer) Decode(ids []int) (string, error) { + var sb strings.Builder + for i, id := range ids { + tok, ok := t.reverseVocab[id] + if !ok { + return "", fmt.Errorf("unknown token ID: %d", id) + } + // Skip special tokens in decode output. + if t.isSpecialToken(tok) { + continue + } + if strings.HasPrefix(tok, "##") { + sb.WriteString(tok[2:]) + } else { + if i > 0 && sb.Len() > 0 { + sb.WriteByte(' ') + } + sb.WriteString(tok) + } + } + return sb.String(), nil +} + +// VocabSize returns the number of tokens in the vocabulary. +func (t *WordPieceTokenizer) VocabSize() int { + return len(t.vocab) +} + +// GetToken returns the string token for a given ID. +func (t *WordPieceTokenizer) GetToken(id int) (string, bool) { + tok, ok := t.reverseVocab[id] + return tok, ok +} + +// GetID returns the token ID for a given string. +func (t *WordPieceTokenizer) GetID(token string) (int, bool) { + id, ok := t.vocab[token] + return id, ok +} + +// SpecialTokens returns the special token IDs. +func (t *WordPieceTokenizer) SpecialTokens() SpecialTokens { + return t.special +} + +// EncodeForBERT tokenizes one or two sentences into the BERT input format. +// For a single sentence: [CLS] tokens [SEP] +// For a sentence pair: [CLS] tokens_a [SEP] tokens_b [SEP] +// The result is padded to maxLen if maxLen > 0. +func (t *WordPieceTokenizer) EncodeForBERT(textA string, textB string, maxLen int) (*BERTEncoding, error) { + idsA, err := t.Encode(textA) + if err != nil { + return nil, fmt.Errorf("encode text_a: %w", err) + } + + clsID, ok := t.vocab["[CLS]"] + if !ok { + return nil, fmt.Errorf("vocabulary missing [CLS] token") + } + sepID, ok := t.vocab["[SEP]"] + if !ok { + return nil, fmt.Errorf("vocabulary missing [SEP] token") + } + + // Build input_ids: [CLS] + tokens_a + [SEP] + inputIDs := make([]int, 0, len(idsA)+3) + inputIDs = append(inputIDs, clsID) + inputIDs = append(inputIDs, idsA...) + inputIDs = append(inputIDs, sepID) + + // token_type_ids: 0 for first sentence segment + tokenTypeIDs := make([]int, len(inputIDs)) + + if textB != "" { + idsB, err := t.Encode(textB) + if err != nil { + return nil, fmt.Errorf("encode text_b: %w", err) + } + secondStart := len(inputIDs) + inputIDs = append(inputIDs, idsB...) + inputIDs = append(inputIDs, sepID) + // Extend token_type_ids: 1 for second sentence segment + for range len(idsB) + 1 { + tokenTypeIDs = append(tokenTypeIDs, 1) + } + _ = secondStart + } + + seqLen := len(inputIDs) + + // Pad if maxLen specified. + if maxLen > 0 && seqLen < maxLen { + padCount := maxLen - seqLen + for range padCount { + inputIDs = append(inputIDs, t.special.PAD) + tokenTypeIDs = append(tokenTypeIDs, 0) + } + } + + // Build attention_mask: 1 for real tokens, 0 for padding. + attentionMask := make([]int, len(inputIDs)) + for i := range seqLen { + attentionMask[i] = 1 + } + + return &BERTEncoding{ + InputIDs: inputIDs, + AttentionMask: attentionMask, + TokenTypeIDs: tokenTypeIDs, + }, nil +} + +// tokenizeWord applies the WordPiece algorithm to a single pre-tokenized word. +// It greedily matches the longest prefix in the vocabulary, continuing with +// "##"-prefixed subwords for the remainder. +func (t *WordPieceTokenizer) tokenizeWord(word string) []int { + if _, ok := t.vocab[word]; ok { + return []int{t.vocab[word]} + } + + var ids []int + start := 0 + runes := []rune(word) + + for start < len(runes) { + end := len(runes) + if end-start > t.maxTokenLen { + end = start + t.maxTokenLen + } + matched := false + for end > start { + substr := string(runes[start:end]) + if start > 0 { + substr = "##" + substr + } + if id, ok := t.vocab[substr]; ok { + ids = append(ids, id) + start = end + matched = true + break + } + end-- + } + if !matched { + // No subword match found — entire remaining word is UNK. + return []int{t.special.UNK} + } + } + return ids +} + +// isSpecialToken returns true if the token string is a known special token. +func (t *WordPieceTokenizer) isSpecialToken(tok string) bool { + switch tok { + case "[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]": + return true + } + if _, ok := t.specialTokens[tok]; ok { + return true + } + return false +} + +// preTokenize splits text on whitespace and punctuation boundaries, +// producing individual words and punctuation characters as separate tokens. +func preTokenize(text string) []string { + var tokens []string + var current strings.Builder + for _, r := range text { + if unicode.IsSpace(r) { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + continue + } + if unicode.IsPunct(r) { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + tokens = append(tokens, string(r)) + continue + } + current.WriteRune(r) + } + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + return tokens +} + +// Statically assert WordPieceTokenizer implements Tokenizer. +var _ Tokenizer = (*WordPieceTokenizer)(nil) diff --git a/wordpiece_test.go b/wordpiece_test.go new file mode 100644 index 0000000..6923db2 --- /dev/null +++ b/wordpiece_test.go @@ -0,0 +1,362 @@ +package ztoken + +import ( + "os" + "path/filepath" + "testing" +) + +func testWordPieceVocab() map[string]int { + return map[string]int{ + "[PAD]": 0, + "[UNK]": 1, + "[CLS]": 2, + "[SEP]": 3, + "[MASK]": 4, + "hello": 5, + "world": 6, + "un": 7, + "##aff": 8, + "##able": 9, + "the": 10, + "##s": 11, + "cat": 12, + "dog": 13, + "play": 14, + "##ing": 15, + "a": 16, + ",": 17, + ".": 18, + } +} + +func testWordPieceTokenizer() *WordPieceTokenizer { + vocab := testWordPieceVocab() + special := SpecialTokens{ + BOS: 2, // [CLS] + EOS: 3, // [SEP] + PAD: 0, // [PAD] + UNK: 1, // [UNK] + } + return NewWordPieceTokenizer(vocab, special) +} + +func TestWordPieceTokenizer_Encode(t *testing.T) { + tok := testWordPieceTokenizer() + + tests := []struct { + name string + input string + want []int + }{ + {"single word", "hello", []int{5}}, + {"two words", "hello world", []int{5, 6}}, + {"subword split", "unaffable", []int{7, 8, 9}}, + {"unknown word", "xyzzy", []int{1}}, + {"empty string", "", []int{}}, + {"punctuation split", "hello,world", []int{5, 17, 6}}, + {"continuation suffix", "cats", []int{12, 11}}, + {"playing", "playing", []int{14, 15}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ids, err := tok.Encode(tt.input) + if err != nil { + t.Fatalf("Encode(%q) error: %v", tt.input, err) + } + if len(ids) != len(tt.want) { + t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", + tt.input, ids, len(ids), tt.want, len(tt.want)) + } + for i, id := range ids { + if id != tt.want[i] { + t.Errorf("Encode(%q)[%d] = %d, want %d", tt.input, i, id, tt.want[i]) + } + } + }) + } +} + +func TestWordPieceTokenizer_Decode(t *testing.T) { + tok := testWordPieceTokenizer() + + tests := []struct { + name string + ids []int + want string + }{ + {"single word", []int{5}, "hello"}, + {"two words", []int{5, 6}, "hello world"}, + {"subword joined", []int{7, 8, 9}, "unaffable"}, + {"with continuation", []int{12, 11}, "cats"}, + {"skip CLS/SEP", []int{2, 5, 3}, "hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoded, err := tok.Decode(tt.ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", tt.ids, err) + } + if decoded != tt.want { + t.Errorf("Decode(%v) = %q, want %q", tt.ids, decoded, tt.want) + } + }) + } +} + +func TestWordPieceTokenizer_RoundTrip(t *testing.T) { + tok := testWordPieceTokenizer() + + texts := []string{ + "hello world", + "the cat", + "playing", + "cats", + "unaffable", + } + + for _, text := range texts { + ids, err := tok.Encode(text) + if err != nil { + t.Fatalf("Encode(%q) error: %v", text, err) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", ids, err) + } + if decoded != text { + t.Errorf("round-trip: %q -> %v -> %q", text, ids, decoded) + } + } +} + +func TestWordPieceTokenizer_VocabSize(t *testing.T) { + tok := testWordPieceTokenizer() + if got := tok.VocabSize(); got != 19 { + t.Errorf("VocabSize() = %d, want 19", got) + } +} + +func TestWordPieceTokenizer_GetToken(t *testing.T) { + tok := testWordPieceTokenizer() + tok1, ok := tok.GetToken(5) + if !ok || tok1 != "hello" { + t.Errorf("GetToken(5) = (%q, %v), want (\"hello\", true)", tok1, ok) + } + _, ok = tok.GetToken(999) + if ok { + t.Error("GetToken(999) should return false") + } +} + +func TestWordPieceTokenizer_GetID(t *testing.T) { + tok := testWordPieceTokenizer() + id, ok := tok.GetID("hello") + if !ok || id != 5 { + t.Errorf("GetID(\"hello\") = (%d, %v), want (5, true)", id, ok) + } + _, ok = tok.GetID("nonexistent") + if ok { + t.Error("GetID(\"nonexistent\") should return false") + } +} + +func TestWordPieceTokenizer_SpecialTokens(t *testing.T) { + tok := testWordPieceTokenizer() + sp := tok.SpecialTokens() + if sp.BOS != 2 { + t.Errorf("BOS = %d, want 2", sp.BOS) + } + if sp.EOS != 3 { + t.Errorf("EOS = %d, want 3", sp.EOS) + } + if sp.PAD != 0 { + t.Errorf("PAD = %d, want 0", sp.PAD) + } + if sp.UNK != 1 { + t.Errorf("UNK = %d, want 1", sp.UNK) + } +} + +func TestWordPieceTokenizer_EncodeForBERT_Single(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello world", "", 0) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, world=6, [SEP]=3 + wantIDs := []int{2, 5, 6, 3} + wantMask := []int{1, 1, 1, 1} + wantTypes := []int{0, 0, 0, 0} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_EncodeForBERT_Pair(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello", "world", 0) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, [SEP]=3, world=6, [SEP]=3 + wantIDs := []int{2, 5, 3, 6, 3} + wantMask := []int{1, 1, 1, 1, 1} + wantTypes := []int{0, 0, 0, 1, 1} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_EncodeForBERT_Padding(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello", "", 8) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, [SEP]=3, [PAD]=0 x5 + wantIDs := []int{2, 5, 3, 0, 0, 0, 0, 0} + wantMask := []int{1, 1, 1, 0, 0, 0, 0, 0} + wantTypes := []int{0, 0, 0, 0, 0, 0, 0, 0} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_PreTokenize(t *testing.T) { + tests := []struct { + input string + want []string + }{ + {"hello world", []string{"hello", "world"}}, + {"hello,world", []string{"hello", ",", "world"}}, + {"hello. world!", []string{"hello", ".", "world", "!"}}, + {" spaces ", []string{"spaces"}}, + {"", nil}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := preTokenize(tt.input) + if len(got) != len(tt.want) { + t.Fatalf("preTokenize(%q) = %v, want %v", tt.input, got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("preTokenize(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestLoad_WordPiece(t *testing.T) { + fixture := `{ + "model": { + "type": "WordPiece", + "vocab": { + "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4, + "hello": 5, "world": 6, "un": 7, "##aff": 8, "##able": 9 + } + }, + "added_tokens": [ + {"id": 0, "content": "[PAD]", "special": true}, + {"id": 1, "content": "[UNK]", "special": true}, + {"id": 2, "content": "[CLS]", "special": true}, + {"id": 3, "content": "[SEP]", "special": true}, + {"id": 4, "content": "[MASK]", "special": true} + ], + "normalizer": { + "type": "Lowercase" + } +}` + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(fixture), 0o600); err != nil { + t.Fatal(err) + } + + tok, err := Load(path) + if err != nil { + t.Fatalf("Load error: %v", err) + } + + wp, ok := tok.(*WordPieceTokenizer) + if !ok { + t.Fatalf("expected *WordPieceTokenizer, got %T", tok) + } + + // Normalizer should lowercase. + ids, err := wp.Encode("HELLO") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + if len(ids) != 1 || ids[0] != 5 { + t.Errorf("Encode(\"HELLO\") = %v, want [5]", ids) + } + + // Subword splitting. + ids, err = wp.Encode("unaffable") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + wantIDs := []int{7, 8, 9} + assertIntSlice(t, "subword", ids, wantIDs) + + // Special tokens. + sp := tok.SpecialTokens() + if sp.BOS != 2 || sp.EOS != 3 || sp.PAD != 0 || sp.UNK != 1 { + t.Errorf("SpecialTokens = %+v", sp) + } +} + +func TestLoad_BPE(t *testing.T) { + // Load should still work for BPE models. + tok, err := Load(filepath.Join("testdata", "tokenizer.json")) + if err != nil { + t.Fatalf("Load error: %v", err) + } + if _, ok := tok.(*BPETokenizer); !ok { + t.Fatalf("expected *BPETokenizer, got %T", tok) + } +} + +func TestLoad_UnsupportedType(t *testing.T) { + fixture := `{ + "model": {"type": "Unigram", "vocab": {}}, + "added_tokens": [] +}` + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(fixture), 0o600); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for Unigram model type") + } +} + +func assertIntSlice(t *testing.T, name string, got, want []int) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("%s: len=%d, want len=%d; got %v, want %v", name, len(got), len(want), got, want) + } + for i := range got { + if got[i] != want[i] { + t.Errorf("%s[%d] = %d, want %d", name, i, got[i], want[i]) + } + } +} From 9905c983d3d37f2e214c6b7378216c3b1d8d801a Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 26 Mar 2026 09:20:02 -0700 Subject: [PATCH 2/2] feat: add SentencePiece unigram encoding for models without merges SentencePiece unigram models (e.g., Mistral 7B GGUF) provide vocabulary scores but no BPE merge table. Without this, encoding fails silently, producing wrong token IDs and garbage output. Add SetScores() to BPETokenizer and a greedy longest-match encoder that selects tokens by length first, then by score. When merges are empty but scores are present, encodeSegment automatically uses this path instead of BPE merging. Also extend the gguf.Metadata interface with GetFloat32Array and extract tokenizer.ggml.scores in ExtractTokenizer so GGUF-loaded tokenizers automatically use unigram encoding when appropriate. --- bpe.go | 115 +++++++++++++++++++++++++- bpe_test.go | 207 ++++++++++++++++++++++++++++++++++++++++++++++ gguf/gguf.go | 8 ++ gguf/gguf_test.go | 76 ++++++++++++++++- 4 files changed, 398 insertions(+), 8 deletions(-) diff --git a/bpe.go b/bpe.go index 9b394f7..9aa7a1f 100644 --- a/bpe.go +++ b/bpe.go @@ -3,6 +3,7 @@ package ztoken import ( "fmt" "strings" + "unicode/utf8" ) // MergePair represents an adjacent token pair used in BPE merging. @@ -15,6 +16,8 @@ type MergePair struct { // BPETokenizer implements the Tokenizer interface using byte-pair encoding. // It loads vocabulary and merge rules from HuggingFace tokenizer.json format. +// When scores are set and merges are empty, it falls back to SentencePiece +// unigram encoding using greedy longest-match with score-based selection. // // Stable. type BPETokenizer struct { @@ -38,6 +41,13 @@ type BPETokenizer struct { specialTokens map[string]int // normalizer is an optional text normalization function applied before tokenization. normalizer NormalizerFunc + // scores holds SentencePiece unigram scores (negative log probabilities) + // indexed by token ID. When scores are set and merges are empty, the + // tokenizer uses greedy longest-match encoding instead of BPE merging. + scores []float32 + // maxTokenLen caches the length (in bytes) of the longest token in vocab, + // used to bound the search window in sentencePieceEncode. + maxTokenLen int } // NewBPETokenizer creates a BPETokenizer from vocabulary, merge rules, and special tokens. @@ -94,13 +104,21 @@ func (t *BPETokenizer) encodeSegment(text string, addLeadingSpace bool) ([]int, } else { words = strings.Fields(text) } + // When merges are empty but scores are available, use SentencePiece + // unigram encoding (greedy longest-match) instead of BPE merging. + useUnigram := len(t.mergeRanks) == 0 && len(t.scores) > 0 + var ids []int for _, word := range words { - wordIDs, err := t.encodeWord(word) - if err != nil { - return nil, err + if useUnigram { + ids = append(ids, t.sentencePieceEncode(word)...) + } else { + wordIDs, err := t.encodeWord(word) + if err != nil { + return nil, err + } + ids = append(ids, wordIDs...) } - ids = append(ids, wordIDs...) } return ids, nil } @@ -226,6 +244,85 @@ func (t *BPETokenizer) SetSpecialTokenStrings(tokens map[string]int) { t.specialTokens = tokens } +// SetScores sets token scores for SentencePiece unigram encoding. +// When scores are set and merges are empty, the tokenizer uses +// score-based greedy encoding instead of BPE merge-based encoding. +// Scores are indexed by token ID (negative log probabilities). +func (t *BPETokenizer) SetScores(scores []float32) { + t.scores = scores + // Precompute max token length in bytes for search window bounding. + t.maxTokenLen = 0 + for tok := range t.vocab { + if len(tok) > t.maxTokenLen { + t.maxTokenLen = len(tok) + } + } +} + +// sentencePieceEncode tokenizes text using greedy longest-match with +// score-based selection. For each position, it finds all vocabulary tokens +// that match the input at that position, selects the longest match (breaking +// ties by highest score), and advances past the matched token. +// +// This is used for SentencePiece unigram models that provide vocabulary +// scores but no BPE merge table (e.g., Mistral 7B GGUF). +func (t *BPETokenizer) sentencePieceEncode(text string) []int { + if text == "" { + return nil + } + var ids []int + pos := 0 + textBytes := []byte(text) + n := len(textBytes) + + for pos < n { + bestLen := 0 + bestID := t.special.UNK + bestScore := float32(-1e30) + + // Search for the longest matching token at this position. + // Limit search window to maxTokenLen bytes. + maxEnd := pos + t.maxTokenLen + if maxEnd > n { + maxEnd = n + } + + for end := pos + 1; end <= maxEnd; end++ { + candidate := string(textBytes[pos:end]) + if id, ok := t.vocab[candidate]; ok { + candidateLen := end - pos + // Prefer longer matches. For equal length, prefer higher score. + if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) { + bestLen = candidateLen + bestID = id + bestScore = t.tokenScore(id) + } + } + } + + if bestLen == 0 { + // No matching token found; emit UNK and advance by one byte. + ids = append(ids, t.special.UNK) + // Advance past one UTF-8 character, not just one byte. + _, size := decodeRune(textBytes[pos:]) + pos += size + } else { + ids = append(ids, bestID) + pos += bestLen + } + } + return ids +} + +// tokenScore returns the score for a token ID, or 0 if scores are not set +// or the ID is out of range. +func (t *BPETokenizer) tokenScore(id int) float32 { + if id >= 0 && id < len(t.scores) { + return t.scores[id] + } + return 0 +} + // sentencePiecePreTokenize implements SentencePiece-style pre-tokenization. // Text is split on whitespace boundaries. Words that follow a space get ▁ // (U+2581) prepended. Newlines are emitted as separate tokens. @@ -404,5 +501,15 @@ func isPrintableGPT2Byte(b byte) bool { return false } +// decodeRune decodes the first UTF-8 rune from b and returns it with its byte length. +// If b is empty or invalid, it returns utf8.RuneError and 1 to ensure forward progress. +func decodeRune(b []byte) (rune, int) { + r, size := utf8.DecodeRune(b) + if size == 0 { + return utf8.RuneError, 1 + } + return r, size +} + // Statically assert BPETokenizer implements Tokenizer. var _ Tokenizer = (*BPETokenizer)(nil) diff --git a/bpe_test.go b/bpe_test.go index 92db3fe..38c6cdd 100644 --- a/bpe_test.go +++ b/bpe_test.go @@ -398,3 +398,210 @@ func TestBPETokenizer_ByteLevelBPE(t *testing.T) { t.Errorf("Decode(%v) = %q, want \"hi\"", ids, decoded) } } + +// makeTestSentencePieceUnigram creates a SentencePiece unigram tokenizer +// with vocabulary and scores but no merges, simulating Mistral 7B GGUF. +func makeTestSentencePieceUnigram() *BPETokenizer { + vocab := map[string]int{ + "": 0, + "": 1, + "": 2, + "\u2581": 3, // ▁ + "\u2581H": 4, + "\u2581He": 5, + "\u2581Hel": 6, + "\u2581Hell": 7, + "\u2581Hello": 8, + "\u2581w": 9, + "\u2581wo": 10, + "\u2581wor": 11, + "\u2581worl": 12, + "\u2581world": 13, + "H": 14, + "e": 15, + "l": 16, + "o": 17, + "w": 18, + "r": 19, + "d": 20, + "\u2581the": 21, + "\u2581is": 22, + "\u2581a": 23, + "\u2581test": 24, + "t": 25, + "s": 26, + } + + // Scores: higher (less negative) = more likely. Longer tokens get better scores. + scores := make([]float32, 27) + scores[0] = -100 // + scores[1] = -100 // + scores[2] = -100 // + scores[3] = -5.0 // ▁ + scores[4] = -3.0 // ▁H + scores[5] = -2.5 // ▁He + scores[6] = -2.0 // ▁Hel + scores[7] = -1.5 // ▁Hell + scores[8] = -1.0 // ▁Hello (best for "Hello") + scores[9] = -3.0 // ▁w + scores[10] = -2.5 // ▁wo + scores[11] = -2.0 // ▁wor + scores[12] = -1.5 // ▁worl + scores[13] = -1.0 // ▁world (best for "world") + scores[14] = -4.0 // H + scores[15] = -4.0 // e + scores[16] = -4.0 // l + scores[17] = -4.0 // o + scores[18] = -4.0 // w + scores[19] = -4.0 // r + scores[20] = -4.0 // d + scores[21] = -1.0 // ▁the + scores[22] = -1.0 // ▁is + scores[23] = -1.5 // ▁a + scores[24] = -1.0 // ▁test + scores[25] = -4.0 // t + scores[26] = -4.0 // s + + special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0} + // No merges — this is a unigram model. + tok := NewBPETokenizer(vocab, nil, special, false) + tok.SetSentencePiece(true) + tok.SetScores(scores) + return tok +} + +func TestSentencePieceUnigram_Encode(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []struct { + name string + input string + wantIDs []int + }{ + {"single word", "Hello", []int{8}}, // ▁Hello + {"two words", "Hello world", []int{8, 13}}, // ▁Hello ▁world + {"sentence", "the world is a test", []int{21, 13, 22, 23, 24}}, // ▁the ▁world ▁is ▁a ▁test + {"empty string", "", []int{}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ids, err := tok.Encode(tc.input) + if err != nil { + t.Fatalf("Encode(%q) error: %v", tc.input, err) + } + if len(ids) != len(tc.wantIDs) { + t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", tc.input, ids, len(ids), tc.wantIDs, len(tc.wantIDs)) + } + for i, id := range ids { + if id != tc.wantIDs[i] { + t.Errorf("Encode(%q)[%d] = %d, want %d", tc.input, i, id, tc.wantIDs[i]) + } + } + }) + } +} + +func TestSentencePieceUnigram_Decode(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []struct { + name string + ids []int + wantText string + wantErr bool + }{ + {"single token", []int{8}, "Hello", false}, + {"multiple tokens", []int{8, 13}, "Hello world", false}, + {"empty", []int{}, "", false}, + {"unknown ID", []int{999}, "", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := tok.Decode(tc.ids) + if tc.wantErr { + if err == nil { + t.Fatalf("Decode(%v) expected error, got %q", tc.ids, got) + } + return + } + if err != nil { + t.Fatalf("Decode(%v) error: %v", tc.ids, err) + } + if got != tc.wantText { + t.Errorf("Decode(%v) = %q, want %q", tc.ids, got, tc.wantText) + } + }) + } +} + +func TestSentencePieceUnigram_RoundTrip(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []string{"Hello", "Hello world", "the world is a test"} + for _, text := range tests { + ids, err := tok.Encode(text) + if err != nil { + t.Fatalf("Encode(%q) error: %v", text, err) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", ids, err) + } + if decoded != text { + t.Errorf("round-trip failed: %q -> %v -> %q", text, ids, decoded) + } + } +} + +func TestSentencePieceUnigram_UnknownChars(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // Characters not in vocab should produce UNK tokens. + ids, err := tok.Encode("xyz") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab, + // the greedy matcher will match ▁ first, then x, y, z individually → all UNK. + for _, id := range ids { + if id != tok.special.UNK { + // Either ▁ (id=3) or UNK (id=0) are acceptable since ▁ is in vocab. + if id != 3 { + t.Errorf("expected UNK or ▁ token for unknown chars, got id=%d", id) + } + } + } +} + +func TestSentencePieceUnigram_PrefersLongestMatch(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // "Hello" should encode as one token ▁Hello (id=8), not ▁H + e + l + l + o. + ids, err := tok.Encode("Hello") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + if len(ids) != 1 { + t.Errorf("expected 1 token for 'Hello', got %d: %v", len(ids), ids) + } + if ids[0] != 8 { + t.Errorf("expected token id 8 (▁Hello), got %d", ids[0]) + } +} + +func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) { + // When merges ARE present, unigram encoding should NOT be used + // even if scores are also set. + tok := makeTestBPE() + tok.SetScores([]float32{0, 0, 0, 0}) // set scores but merges exist + ids, err := tok.Encode("hello") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // Should still use BPE merging, producing "hello" (id=17). + if len(ids) != 1 || ids[0] != 17 { + t.Errorf("with merges present, expected BPE encoding [17], got %v", ids) + } +} diff --git a/gguf/gguf.go b/gguf/gguf.go index d750af1..40e16bf 100644 --- a/gguf/gguf.go +++ b/gguf/gguf.go @@ -19,6 +19,7 @@ type Metadata interface { GetStringArray(key string) ([]string, bool) GetUint32(key string) (uint32, bool) GetInt32Array(key string) ([]int32, bool) + GetFloat32Array(key string) ([]float32, bool) } // ExtractTokenizer builds a BPETokenizer from GGUF metadata. GGUF files store @@ -72,6 +73,13 @@ func ExtractTokenizer(m Metadata) (*ztoken.BPETokenizer, error) { tok.SetSentencePiece(true) } + // Extract token scores for SentencePiece unigram models. When scores + // are present but merges are absent, the tokenizer uses greedy + // longest-match encoding instead of BPE merge-based encoding. + if scores, ok := m.GetFloat32Array("tokenizer.ggml.scores"); ok { + tok.SetScores(scores) + } + // Extract control/special tokens (token_type == 3) for exact matching // during encoding. Without this, tokens like would be // split into characters by BPE. diff --git a/gguf/gguf_test.go b/gguf/gguf_test.go index 4fcb6ba..d69223f 100644 --- a/gguf/gguf_test.go +++ b/gguf/gguf_test.go @@ -6,10 +6,11 @@ import ( // testMetadata implements Metadata for testing. type testMetadata struct { - strings map[string]string - stringArrays map[string][]string - uint32s map[string]uint32 - int32Arrays map[string][]int32 + strings map[string]string + stringArrays map[string][]string + uint32s map[string]uint32 + int32Arrays map[string][]int32 + float32Arrays map[string][]float32 } func (m *testMetadata) GetString(key string) (string, bool) { @@ -32,6 +33,14 @@ func (m *testMetadata) GetInt32Array(key string) ([]int32, bool) { return v, ok } +func (m *testMetadata) GetFloat32Array(key string) ([]float32, bool) { + if m.float32Arrays == nil { + return nil, false + } + v, ok := m.float32Arrays[key] + return v, ok +} + func TestExtractTokenizer(t *testing.T) { m := &testMetadata{ strings: map[string]string{}, @@ -144,6 +153,65 @@ func TestExtractTokenizer_ControlTokens(t *testing.T) { } } +func TestExtractTokenizer_SentencePieceUnigram(t *testing.T) { + // Simulate a Mistral 7B GGUF: llama model with scores but no merges. + m := &testMetadata{ + strings: map[string]string{ + "tokenizer.ggml.model": "llama", + }, + stringArrays: map[string][]string{ + "tokenizer.ggml.tokens": { + "", "", "", + "\u2581Hello", "\u2581world", + "H", "e", "l", "o", "w", "r", "d", + }, + // No merges key at all. + }, + uint32s: map[string]uint32{ + "tokenizer.ggml.bos_token_id": 1, + "tokenizer.ggml.eos_token_id": 2, + "tokenizer.ggml.unknown_token_id": 0, + }, + int32Arrays: map[string][]int32{}, + float32Arrays: map[string][]float32{ + "tokenizer.ggml.scores": { + -100, -100, -100, // , , + -1.0, -1.0, // ▁Hello, ▁world (high score = preferred) + -5.0, -5.0, -5.0, -5.0, -5.0, -5.0, -5.0, // individual chars + }, + }, + } + + tok, err := ExtractTokenizer(m) + if err != nil { + t.Fatalf("ExtractTokenizer error: %v", err) + } + + // Encode "Hello world" should produce [▁Hello, ▁world] = [3, 4]. + ids, err := tok.Encode("Hello world") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + want := []int{3, 4} + if len(ids) != len(want) { + t.Fatalf("Encode(\"Hello world\") = %v (len=%d), want %v", ids, len(ids), want) + } + for i, id := range ids { + if id != want[i] { + t.Errorf("Encode[%d] = %d, want %d", i, id, want[i]) + } + } + + // Decode should round-trip. + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != "Hello world" { + t.Errorf("Decode = %q, want %q", decoded, "Hello world") + } +} + func TestExtractTokenizer_InvalidMerge(t *testing.T) { m := &testMetadata{ strings: map[string]string{},