diff --git a/bpe.go b/bpe.go index 9b394f7..9aa7a1f 100644 --- a/bpe.go +++ b/bpe.go @@ -3,6 +3,7 @@ package ztoken import ( "fmt" "strings" + "unicode/utf8" ) // MergePair represents an adjacent token pair used in BPE merging. @@ -15,6 +16,8 @@ type MergePair struct { // BPETokenizer implements the Tokenizer interface using byte-pair encoding. // It loads vocabulary and merge rules from HuggingFace tokenizer.json format. +// When scores are set and merges are empty, it falls back to SentencePiece +// unigram encoding using greedy longest-match with score-based selection. // // Stable. type BPETokenizer struct { @@ -38,6 +41,13 @@ type BPETokenizer struct { specialTokens map[string]int // normalizer is an optional text normalization function applied before tokenization. normalizer NormalizerFunc + // scores holds SentencePiece unigram scores (negative log probabilities) + // indexed by token ID. When scores are set and merges are empty, the + // tokenizer uses greedy longest-match encoding instead of BPE merging. + scores []float32 + // maxTokenLen caches the length (in bytes) of the longest token in vocab, + // used to bound the search window in sentencePieceEncode. + maxTokenLen int } // NewBPETokenizer creates a BPETokenizer from vocabulary, merge rules, and special tokens. @@ -94,13 +104,21 @@ func (t *BPETokenizer) encodeSegment(text string, addLeadingSpace bool) ([]int, } else { words = strings.Fields(text) } + // When merges are empty but scores are available, use SentencePiece + // unigram encoding (greedy longest-match) instead of BPE merging. + useUnigram := len(t.mergeRanks) == 0 && len(t.scores) > 0 + var ids []int for _, word := range words { - wordIDs, err := t.encodeWord(word) - if err != nil { - return nil, err + if useUnigram { + ids = append(ids, t.sentencePieceEncode(word)...) + } else { + wordIDs, err := t.encodeWord(word) + if err != nil { + return nil, err + } + ids = append(ids, wordIDs...) } - ids = append(ids, wordIDs...) 
} return ids, nil } @@ -226,6 +244,85 @@ func (t *BPETokenizer) SetSpecialTokenStrings(tokens map[string]int) { t.specialTokens = tokens } +// SetScores sets token scores for SentencePiece unigram encoding. +// When scores are set and merges are empty, the tokenizer uses +// score-based greedy encoding instead of BPE merge-based encoding. +// Scores are indexed by token ID (negative log probabilities). +func (t *BPETokenizer) SetScores(scores []float32) { + t.scores = scores + // Precompute max token length in bytes for search window bounding. + t.maxTokenLen = 0 + for tok := range t.vocab { + if len(tok) > t.maxTokenLen { + t.maxTokenLen = len(tok) + } + } +} + +// sentencePieceEncode tokenizes text using greedy longest-match with +// score-based selection. For each position, it finds all vocabulary tokens +// that match the input at that position, selects the longest match (breaking +// ties by highest score), and advances past the matched token. +// +// This is used for SentencePiece unigram models that provide vocabulary +// scores but no BPE merge table (e.g., Mistral 7B GGUF). +func (t *BPETokenizer) sentencePieceEncode(text string) []int { + if text == "" { + return nil + } + var ids []int + pos := 0 + textBytes := []byte(text) + n := len(textBytes) + + for pos < n { + bestLen := 0 + bestID := t.special.UNK + bestScore := float32(-1e30) + + // Search for the longest matching token at this position. + // Limit search window to maxTokenLen bytes. + maxEnd := pos + t.maxTokenLen + if maxEnd > n { + maxEnd = n + } + + for end := pos + 1; end <= maxEnd; end++ { + candidate := string(textBytes[pos:end]) + if id, ok := t.vocab[candidate]; ok { + candidateLen := end - pos + // Prefer longer matches. For equal length, prefer higher score. 
+ if candidateLen > bestLen || (candidateLen == bestLen && t.tokenScore(id) > bestScore) { + bestLen = candidateLen + bestID = id + bestScore = t.tokenScore(id) + } + } + } + + if bestLen == 0 { + // No matching token found; emit UNK and advance by one byte. + ids = append(ids, t.special.UNK) + // Advance past one UTF-8 character, not just one byte. + _, size := decodeRune(textBytes[pos:]) + pos += size + } else { + ids = append(ids, bestID) + pos += bestLen + } + } + return ids +} + +// tokenScore returns the score for a token ID, or 0 if scores are not set +// or the ID is out of range. +func (t *BPETokenizer) tokenScore(id int) float32 { + if id >= 0 && id < len(t.scores) { + return t.scores[id] + } + return 0 +} + // sentencePiecePreTokenize implements SentencePiece-style pre-tokenization. // Text is split on whitespace boundaries. Words that follow a space get ▁ // (U+2581) prepended. Newlines are emitted as separate tokens. @@ -404,5 +501,15 @@ func isPrintableGPT2Byte(b byte) bool { return false } +// decodeRune decodes the first UTF-8 rune from b and returns it with its byte length. +// If b is empty or invalid, it returns utf8.RuneError and 1 to ensure forward progress. +func decodeRune(b []byte) (rune, int) { + r, size := utf8.DecodeRune(b) + if size == 0 { + return utf8.RuneError, 1 + } + return r, size +} + // Statically assert BPETokenizer implements Tokenizer. var _ Tokenizer = (*BPETokenizer)(nil) diff --git a/bpe_test.go b/bpe_test.go index 92db3fe..38c6cdd 100644 --- a/bpe_test.go +++ b/bpe_test.go @@ -398,3 +398,210 @@ func TestBPETokenizer_ByteLevelBPE(t *testing.T) { t.Errorf("Decode(%v) = %q, want \"hi\"", ids, decoded) } } + +// makeTestSentencePieceUnigram creates a SentencePiece unigram tokenizer +// with vocabulary and scores but no merges, simulating Mistral 7B GGUF. 
+func makeTestSentencePieceUnigram() *BPETokenizer { + vocab := map[string]int{ + "<unk>": 0, + "<s>": 1, + "</s>": 2, + "\u2581": 3, // ▁ + "\u2581H": 4, + "\u2581He": 5, + "\u2581Hel": 6, + "\u2581Hell": 7, + "\u2581Hello": 8, + "\u2581w": 9, + "\u2581wo": 10, + "\u2581wor": 11, + "\u2581worl": 12, + "\u2581world": 13, + "H": 14, + "e": 15, + "l": 16, + "o": 17, + "w": 18, + "r": 19, + "d": 20, + "\u2581the": 21, + "\u2581is": 22, + "\u2581a": 23, + "\u2581test": 24, + "t": 25, + "s": 26, + } + + // Scores: higher (less negative) = more likely. Longer tokens get better scores. + scores := make([]float32, 27) + scores[0] = -100 // <unk> + scores[1] = -100 // <s> + scores[2] = -100 // </s> + scores[3] = -5.0 // ▁ + scores[4] = -3.0 // ▁H + scores[5] = -2.5 // ▁He + scores[6] = -2.0 // ▁Hel + scores[7] = -1.5 // ▁Hell + scores[8] = -1.0 // ▁Hello (best for "Hello") + scores[9] = -3.0 // ▁w + scores[10] = -2.5 // ▁wo + scores[11] = -2.0 // ▁wor + scores[12] = -1.5 // ▁worl + scores[13] = -1.0 // ▁world (best for "world") + scores[14] = -4.0 // H + scores[15] = -4.0 // e + scores[16] = -4.0 // l + scores[17] = -4.0 // o + scores[18] = -4.0 // w + scores[19] = -4.0 // r + scores[20] = -4.0 // d + scores[21] = -1.0 // ▁the + scores[22] = -1.0 // ▁is + scores[23] = -1.5 // ▁a + scores[24] = -1.0 // ▁test + scores[25] = -4.0 // t + scores[26] = -4.0 // s + + special := SpecialTokens{BOS: 1, EOS: 2, PAD: 0, UNK: 0} + // No merges — this is a unigram model. 
+ tok := NewBPETokenizer(vocab, nil, special, false) + tok.SetSentencePiece(true) + tok.SetScores(scores) + return tok +} + +func TestSentencePieceUnigram_Encode(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []struct { + name string + input string + wantIDs []int + }{ + {"single word", "Hello", []int{8}}, // ▁Hello + {"two words", "Hello world", []int{8, 13}}, // ▁Hello ▁world + {"sentence", "the world is a test", []int{21, 13, 22, 23, 24}}, // ▁the ▁world ▁is ▁a ▁test + {"empty string", "", []int{}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ids, err := tok.Encode(tc.input) + if err != nil { + t.Fatalf("Encode(%q) error: %v", tc.input, err) + } + if len(ids) != len(tc.wantIDs) { + t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", tc.input, ids, len(ids), tc.wantIDs, len(tc.wantIDs)) + } + for i, id := range ids { + if id != tc.wantIDs[i] { + t.Errorf("Encode(%q)[%d] = %d, want %d", tc.input, i, id, tc.wantIDs[i]) + } + } + }) + } +} + +func TestSentencePieceUnigram_Decode(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []struct { + name string + ids []int + wantText string + wantErr bool + }{ + {"single token", []int{8}, "Hello", false}, + {"multiple tokens", []int{8, 13}, "Hello world", false}, + {"empty", []int{}, "", false}, + {"unknown ID", []int{999}, "", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := tok.Decode(tc.ids) + if tc.wantErr { + if err == nil { + t.Fatalf("Decode(%v) expected error, got %q", tc.ids, got) + } + return + } + if err != nil { + t.Fatalf("Decode(%v) error: %v", tc.ids, err) + } + if got != tc.wantText { + t.Errorf("Decode(%v) = %q, want %q", tc.ids, got, tc.wantText) + } + }) + } +} + +func TestSentencePieceUnigram_RoundTrip(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + tests := []string{"Hello", "Hello world", "the world is a test"} + for _, text := range tests { + ids, err := tok.Encode(text) + 
if err != nil { + t.Fatalf("Encode(%q) error: %v", text, err) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", ids, err) + } + if decoded != text { + t.Errorf("round-trip failed: %q -> %v -> %q", text, ids, decoded) + } + } +} + +func TestSentencePieceUnigram_UnknownChars(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // Characters not in vocab should produce UNK tokens. + ids, err := tok.Encode("xyz") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // "xyz" -> pre-tokenized as "▁xyz". Since ▁x, ▁y, ▁z are not in vocab, + // the greedy matcher will match ▁ first, then x, y, z individually → all UNK. + for _, id := range ids { + if id != tok.special.UNK { + // Either ▁ (id=3) or UNK (id=0) are acceptable since ▁ is in vocab. + if id != 3 { + t.Errorf("expected UNK or ▁ token for unknown chars, got id=%d", id) + } + } + } +} + +func TestSentencePieceUnigram_PrefersLongestMatch(t *testing.T) { + tok := makeTestSentencePieceUnigram() + + // "Hello" should encode as one token ▁Hello (id=8), not ▁H + e + l + l + o. + ids, err := tok.Encode("Hello") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + if len(ids) != 1 { + t.Errorf("expected 1 token for 'Hello', got %d: %v", len(ids), ids) + } + if ids[0] != 8 { + t.Errorf("expected token id 8 (▁Hello), got %d", ids[0]) + } +} + +func TestSentencePieceUnigram_WithBPEFallback(t *testing.T) { + // When merges ARE present, unigram encoding should NOT be used + // even if scores are also set. + tok := makeTestBPE() + tok.SetScores([]float32{0, 0, 0, 0}) // set scores but merges exist + ids, err := tok.Encode("hello") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + // Should still use BPE merging, producing "hello" (id=17). 
+ if len(ids) != 1 || ids[0] != 17 { + t.Errorf("with merges present, expected BPE encoding [17], got %v", ids) + } +} diff --git a/gguf/gguf.go b/gguf/gguf.go index d750af1..40e16bf 100644 --- a/gguf/gguf.go +++ b/gguf/gguf.go @@ -19,6 +19,7 @@ type Metadata interface { GetStringArray(key string) ([]string, bool) GetUint32(key string) (uint32, bool) GetInt32Array(key string) ([]int32, bool) + GetFloat32Array(key string) ([]float32, bool) } // ExtractTokenizer builds a BPETokenizer from GGUF metadata. GGUF files store @@ -72,6 +73,13 @@ func ExtractTokenizer(m Metadata) (*ztoken.BPETokenizer, error) { tok.SetSentencePiece(true) } + // Extract token scores for SentencePiece unigram models. When scores + // are present but merges are absent, the tokenizer uses greedy + // longest-match encoding instead of BPE merge-based encoding. + if scores, ok := m.GetFloat32Array("tokenizer.ggml.scores"); ok { + tok.SetScores(scores) + } + // Extract control/special tokens (token_type == 3) for exact matching // during encoding. Without this, tokens like would be // split into characters by BPE. diff --git a/gguf/gguf_test.go b/gguf/gguf_test.go index 4fcb6ba..d69223f 100644 --- a/gguf/gguf_test.go +++ b/gguf/gguf_test.go @@ -6,10 +6,11 @@ import ( // testMetadata implements Metadata for testing. 
type testMetadata struct { - strings map[string]string - stringArrays map[string][]string - uint32s map[string]uint32 - int32Arrays map[string][]int32 + strings map[string]string + stringArrays map[string][]string + uint32s map[string]uint32 + int32Arrays map[string][]int32 + float32Arrays map[string][]float32 } func (m *testMetadata) GetString(key string) (string, bool) { @@ -32,6 +33,14 @@ func (m *testMetadata) GetInt32Array(key string) ([]int32, bool) { return v, ok } +func (m *testMetadata) GetFloat32Array(key string) ([]float32, bool) { + if m.float32Arrays == nil { + return nil, false + } + v, ok := m.float32Arrays[key] + return v, ok +} + func TestExtractTokenizer(t *testing.T) { m := &testMetadata{ strings: map[string]string{}, @@ -144,6 +153,65 @@ func TestExtractTokenizer_ControlTokens(t *testing.T) { } } +func TestExtractTokenizer_SentencePieceUnigram(t *testing.T) { + // Simulate a Mistral 7B GGUF: llama model with scores but no merges. + m := &testMetadata{ + strings: map[string]string{ + "tokenizer.ggml.model": "llama", + }, + stringArrays: map[string][]string{ + "tokenizer.ggml.tokens": { + "<unk>", "<s>", "</s>", + "\u2581Hello", "\u2581world", + "H", "e", "l", "o", "w", "r", "d", + }, + // No merges key at all. + }, + uint32s: map[string]uint32{ + "tokenizer.ggml.bos_token_id": 1, + "tokenizer.ggml.eos_token_id": 2, + "tokenizer.ggml.unknown_token_id": 0, + }, + int32Arrays: map[string][]int32{}, + float32Arrays: map[string][]float32{ + "tokenizer.ggml.scores": { + -100, -100, -100, // <unk>, <s>, </s> + -1.0, -1.0, // ▁Hello, ▁world (high score = preferred) + -5.0, -5.0, -5.0, -5.0, -5.0, -5.0, -5.0, // individual chars + }, + }, + } + + tok, err := ExtractTokenizer(m) + if err != nil { + t.Fatalf("ExtractTokenizer error: %v", err) + } + + // Encode "Hello world" should produce [▁Hello, ▁world] = [3, 4]. 
+ ids, err := tok.Encode("Hello world") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + want := []int{3, 4} + if len(ids) != len(want) { + t.Fatalf("Encode(\"Hello world\") = %v (len=%d), want %v", ids, len(ids), want) + } + for i, id := range ids { + if id != want[i] { + t.Errorf("Encode[%d] = %d, want %d", i, id, want[i]) + } + } + + // Decode should round-trip. + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if decoded != "Hello world" { + t.Errorf("Decode = %q, want %q", decoded, "Hello world") + } +} + func TestExtractTokenizer_InvalidMerge(t *testing.T) { m := &testMetadata{ strings: map[string]string{}, diff --git a/loader.go b/loader.go index b6e05f8..8e9939e 100644 --- a/loader.go +++ b/loader.go @@ -20,9 +20,10 @@ type tokenizerJSON struct { } type modelJSON struct { - Type string `json:"type"` - Vocab map[string]int `json:"vocab"` - RawMerges json.RawMessage `json:"merges"` + Type string `json:"type"` + Vocab map[string]int `json:"vocab"` + RawMerges json.RawMessage `json:"merges"` + ContinuingSubwordPrefix string `json:"continuing_subword_prefix"` } type addedTokenJSON struct { @@ -52,7 +53,31 @@ type decoderPatternJSON struct { String string `json:"String"` } +// Load reads a HuggingFace tokenizer.json file and returns the appropriate +// Tokenizer implementation based on the model type (BPE or WordPiece). 
+func Load(path string) (Tokenizer, error) { + data, err := os.ReadFile(path) //nolint:gosec // user-provided path + if err != nil { + return nil, fmt.Errorf("read tokenizer.json: %w", err) + } + + var tj tokenizerJSON + if err := json.Unmarshal(data, &tj); err != nil { + return nil, fmt.Errorf("parse tokenizer.json: %w", err) + } + + switch tj.Model.Type { + case "WordPiece": + return loadWordPiece(tj) + case "", "BPE": + return loadBPE(tj) + default: + return nil, fmt.Errorf("unsupported model type: %q (supported: BPE, WordPiece)", tj.Model.Type) + } +} + // LoadFromJSON reads a HuggingFace tokenizer.json file and returns a BPETokenizer. +// For loading any tokenizer type, use [Load] instead. func LoadFromJSON(path string) (*BPETokenizer, error) { data, err := os.ReadFile(path) //nolint:gosec // user-provided path if err != nil { @@ -68,6 +93,11 @@ func LoadFromJSON(path string) (*BPETokenizer, error) { return nil, fmt.Errorf("unsupported model type: %q (only BPE supported)", tj.Model.Type) } + return loadBPE(tj) +} + +// loadBPE constructs a BPETokenizer from parsed JSON. +func loadBPE(tj tokenizerJSON) (*BPETokenizer, error) { // Parse merges — supports both ["a b", …] and [["a","b"], …] formats. merges, err := parseMerges(tj.Model.RawMerges) if err != nil { @@ -94,6 +124,26 @@ func LoadFromJSON(path string) (*BPETokenizer, error) { return tok, nil } +// loadWordPiece constructs a WordPieceTokenizer from parsed JSON. +func loadWordPiece(tj tokenizerJSON) (*WordPieceTokenizer, error) { + special := extractSpecialTokens(tj.AddedTokens) + normalizer := buildNormalizer(tj.Normalizer) + + tok := NewWordPieceTokenizer(tj.Model.Vocab, special) + tok.normalizer = normalizer + + // Register special token strings for exact matching. 
+ specialMap := make(map[string]int) + for _, at := range tj.AddedTokens { + if at.Special { + specialMap[at.Content] = at.ID + } + } + tok.specialTokens = specialMap + + return tok, nil +} + // isByteLevelPreTokenizer returns true if the pre-tokenizer config uses ByteLevel. func isByteLevelPreTokenizer(pt *preTokenizerJSON) bool { if pt == nil { @@ -136,6 +186,8 @@ func isSentencePieceDecoder(d *decoderJSON) bool { } // extractSpecialTokens finds BOS, EOS, PAD, UNK from added_tokens. +// Recognizes both GPT-style (<s>, </s>, <pad>, <unk>) and BERT-style +// ([CLS], [SEP], [PAD], [UNK]) special token conventions. func extractSpecialTokens(tokens []addedTokenJSON) SpecialTokens { special := SpecialTokens{} for _, t := range tokens { @@ -143,13 +195,13 @@ continue } switch { - case strings.Contains(t.Content, "bos") || t.Content == "<s>": + case strings.Contains(t.Content, "bos") || t.Content == "<s>" || t.Content == "[CLS]": special.BOS = t.ID - case strings.Contains(t.Content, "eos") || t.Content == "</s>": + case strings.Contains(t.Content, "eos") || t.Content == "</s>" || t.Content == "[SEP]": special.EOS = t.ID - case strings.Contains(t.Content, "pad") || t.Content == "<pad>": + case strings.Contains(t.Content, "pad") || t.Content == "<pad>" || t.Content == "[PAD]": special.PAD = t.ID - case strings.Contains(t.Content, "unk") || t.Content == "<unk>": + case strings.Contains(t.Content, "unk") || t.Content == "<unk>" || t.Content == "[UNK]": special.UNK = t.ID } } diff --git a/wordpiece.go b/wordpiece.go new file mode 100644 index 0000000..52d686f --- /dev/null +++ b/wordpiece.go @@ -0,0 +1,272 @@ +package ztoken + +import ( + "fmt" + "strings" + "unicode" +) + +// WordPieceTokenizer implements the Tokenizer interface using the WordPiece +// algorithm, as used by BERT-family models. It greedily matches the longest +// subword prefix from the vocabulary, using "##" to denote continuation tokens. +// +// Stable. 
+type WordPieceTokenizer struct { + vocab map[string]int + reverseVocab map[int]string + special SpecialTokens + normalizer NormalizerFunc + // maxTokenLen is the length of the longest token in the vocabulary, + // used to bound the greedy prefix search. + maxTokenLen int + // unkToken is the string representation of the unknown token. + unkToken string + // specialTokens maps special token strings to IDs for exact matching. + specialTokens map[string]int +} + +// BERTEncoding holds the input tensors expected by BERT-family models. +type BERTEncoding struct { + InputIDs []int // Token IDs: [CLS] + tokens + [SEP] (+ tokens + [SEP] for pairs) + AttentionMask []int // 1 for real tokens, 0 for padding + TokenTypeIDs []int // 0 for first sentence, 1 for second sentence +} + +// NewWordPieceTokenizer creates a WordPieceTokenizer from a vocabulary and special tokens. +func NewWordPieceTokenizer(vocab map[string]int, special SpecialTokens) *WordPieceTokenizer { + reverseVocab := make(map[int]string, len(vocab)) + maxLen := 0 + for k, v := range vocab { + reverseVocab[v] = k + if len(k) > maxLen { + maxLen = len(k) + } + } + + unkToken := "[UNK]" + if tok, ok := reverseVocab[special.UNK]; ok { + unkToken = tok + } + + return &WordPieceTokenizer{ + vocab: vocab, + reverseVocab: reverseVocab, + special: special, + maxTokenLen: maxLen, + unkToken: unkToken, + } +} + +// Encode tokenizes text into a sequence of token IDs using WordPiece. +func (t *WordPieceTokenizer) Encode(text string) ([]int, error) { + if text == "" { + return []int{}, nil + } + if t.normalizer != nil { + text = t.normalizer(text) + } + words := preTokenize(text) + var ids []int + for _, word := range words { + wordIDs := t.tokenizeWord(word) + ids = append(ids, wordIDs...) + } + return ids, nil +} + +// Decode converts token IDs back to text. Continuation tokens (##prefix) are +// joined without spaces to reconstruct words. 
+func (t *WordPieceTokenizer) Decode(ids []int) (string, error) { + var sb strings.Builder + for i, id := range ids { + tok, ok := t.reverseVocab[id] + if !ok { + return "", fmt.Errorf("unknown token ID: %d", id) + } + // Skip special tokens in decode output. + if t.isSpecialToken(tok) { + continue + } + if strings.HasPrefix(tok, "##") { + sb.WriteString(tok[2:]) + } else { + if i > 0 && sb.Len() > 0 { + sb.WriteByte(' ') + } + sb.WriteString(tok) + } + } + return sb.String(), nil +} + +// VocabSize returns the number of tokens in the vocabulary. +func (t *WordPieceTokenizer) VocabSize() int { + return len(t.vocab) +} + +// GetToken returns the string token for a given ID. +func (t *WordPieceTokenizer) GetToken(id int) (string, bool) { + tok, ok := t.reverseVocab[id] + return tok, ok +} + +// GetID returns the token ID for a given string. +func (t *WordPieceTokenizer) GetID(token string) (int, bool) { + id, ok := t.vocab[token] + return id, ok +} + +// SpecialTokens returns the special token IDs. +func (t *WordPieceTokenizer) SpecialTokens() SpecialTokens { + return t.special +} + +// EncodeForBERT tokenizes one or two sentences into the BERT input format. +// For a single sentence: [CLS] tokens [SEP] +// For a sentence pair: [CLS] tokens_a [SEP] tokens_b [SEP] +// The result is padded to maxLen if maxLen > 0. +func (t *WordPieceTokenizer) EncodeForBERT(textA string, textB string, maxLen int) (*BERTEncoding, error) { + idsA, err := t.Encode(textA) + if err != nil { + return nil, fmt.Errorf("encode text_a: %w", err) + } + + clsID, ok := t.vocab["[CLS]"] + if !ok { + return nil, fmt.Errorf("vocabulary missing [CLS] token") + } + sepID, ok := t.vocab["[SEP]"] + if !ok { + return nil, fmt.Errorf("vocabulary missing [SEP] token") + } + + // Build input_ids: [CLS] + tokens_a + [SEP] + inputIDs := make([]int, 0, len(idsA)+3) + inputIDs = append(inputIDs, clsID) + inputIDs = append(inputIDs, idsA...) 
+ inputIDs = append(inputIDs, sepID) + + // token_type_ids: 0 for first sentence segment + tokenTypeIDs := make([]int, len(inputIDs)) + + if textB != "" { + idsB, err := t.Encode(textB) + if err != nil { + return nil, fmt.Errorf("encode text_b: %w", err) + } + secondStart := len(inputIDs) + inputIDs = append(inputIDs, idsB...) + inputIDs = append(inputIDs, sepID) + // Extend token_type_ids: 1 for second sentence segment + for range len(idsB) + 1 { + tokenTypeIDs = append(tokenTypeIDs, 1) + } + _ = secondStart + } + + seqLen := len(inputIDs) + + // Pad if maxLen specified. + if maxLen > 0 && seqLen < maxLen { + padCount := maxLen - seqLen + for range padCount { + inputIDs = append(inputIDs, t.special.PAD) + tokenTypeIDs = append(tokenTypeIDs, 0) + } + } + + // Build attention_mask: 1 for real tokens, 0 for padding. + attentionMask := make([]int, len(inputIDs)) + for i := range seqLen { + attentionMask[i] = 1 + } + + return &BERTEncoding{ + InputIDs: inputIDs, + AttentionMask: attentionMask, + TokenTypeIDs: tokenTypeIDs, + }, nil +} + +// tokenizeWord applies the WordPiece algorithm to a single pre-tokenized word. +// It greedily matches the longest prefix in the vocabulary, continuing with +// "##"-prefixed subwords for the remainder. +func (t *WordPieceTokenizer) tokenizeWord(word string) []int { + if _, ok := t.vocab[word]; ok { + return []int{t.vocab[word]} + } + + var ids []int + start := 0 + runes := []rune(word) + + for start < len(runes) { + end := len(runes) + if end-start > t.maxTokenLen { + end = start + t.maxTokenLen + } + matched := false + for end > start { + substr := string(runes[start:end]) + if start > 0 { + substr = "##" + substr + } + if id, ok := t.vocab[substr]; ok { + ids = append(ids, id) + start = end + matched = true + break + } + end-- + } + if !matched { + // No subword match found — entire remaining word is UNK. 
+ return []int{t.special.UNK} + } + } + return ids +} + +// isSpecialToken returns true if the token string is a known special token. +func (t *WordPieceTokenizer) isSpecialToken(tok string) bool { + switch tok { + case "[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]": + return true + } + if _, ok := t.specialTokens[tok]; ok { + return true + } + return false +} + +// preTokenize splits text on whitespace and punctuation boundaries, +// producing individual words and punctuation characters as separate tokens. +func preTokenize(text string) []string { + var tokens []string + var current strings.Builder + for _, r := range text { + if unicode.IsSpace(r) { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + continue + } + if unicode.IsPunct(r) { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + tokens = append(tokens, string(r)) + continue + } + current.WriteRune(r) + } + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + return tokens +} + +// Statically assert WordPieceTokenizer implements Tokenizer. 
+var _ Tokenizer = (*WordPieceTokenizer)(nil) diff --git a/wordpiece_test.go b/wordpiece_test.go new file mode 100644 index 0000000..6923db2 --- /dev/null +++ b/wordpiece_test.go @@ -0,0 +1,362 @@ +package ztoken + +import ( + "os" + "path/filepath" + "testing" +) + +func testWordPieceVocab() map[string]int { + return map[string]int{ + "[PAD]": 0, + "[UNK]": 1, + "[CLS]": 2, + "[SEP]": 3, + "[MASK]": 4, + "hello": 5, + "world": 6, + "un": 7, + "##aff": 8, + "##able": 9, + "the": 10, + "##s": 11, + "cat": 12, + "dog": 13, + "play": 14, + "##ing": 15, + "a": 16, + ",": 17, + ".": 18, + } +} + +func testWordPieceTokenizer() *WordPieceTokenizer { + vocab := testWordPieceVocab() + special := SpecialTokens{ + BOS: 2, // [CLS] + EOS: 3, // [SEP] + PAD: 0, // [PAD] + UNK: 1, // [UNK] + } + return NewWordPieceTokenizer(vocab, special) +} + +func TestWordPieceTokenizer_Encode(t *testing.T) { + tok := testWordPieceTokenizer() + + tests := []struct { + name string + input string + want []int + }{ + {"single word", "hello", []int{5}}, + {"two words", "hello world", []int{5, 6}}, + {"subword split", "unaffable", []int{7, 8, 9}}, + {"unknown word", "xyzzy", []int{1}}, + {"empty string", "", []int{}}, + {"punctuation split", "hello,world", []int{5, 17, 6}}, + {"continuation suffix", "cats", []int{12, 11}}, + {"playing", "playing", []int{14, 15}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ids, err := tok.Encode(tt.input) + if err != nil { + t.Fatalf("Encode(%q) error: %v", tt.input, err) + } + if len(ids) != len(tt.want) { + t.Fatalf("Encode(%q) = %v (len=%d), want %v (len=%d)", + tt.input, ids, len(ids), tt.want, len(tt.want)) + } + for i, id := range ids { + if id != tt.want[i] { + t.Errorf("Encode(%q)[%d] = %d, want %d", tt.input, i, id, tt.want[i]) + } + } + }) + } +} + +func TestWordPieceTokenizer_Decode(t *testing.T) { + tok := testWordPieceTokenizer() + + tests := []struct { + name string + ids []int + want string + }{ + {"single word", 
[]int{5}, "hello"}, + {"two words", []int{5, 6}, "hello world"}, + {"subword joined", []int{7, 8, 9}, "unaffable"}, + {"with continuation", []int{12, 11}, "cats"}, + {"skip CLS/SEP", []int{2, 5, 3}, "hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoded, err := tok.Decode(tt.ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", tt.ids, err) + } + if decoded != tt.want { + t.Errorf("Decode(%v) = %q, want %q", tt.ids, decoded, tt.want) + } + }) + } +} + +func TestWordPieceTokenizer_RoundTrip(t *testing.T) { + tok := testWordPieceTokenizer() + + texts := []string{ + "hello world", + "the cat", + "playing", + "cats", + "unaffable", + } + + for _, text := range texts { + ids, err := tok.Encode(text) + if err != nil { + t.Fatalf("Encode(%q) error: %v", text, err) + } + decoded, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v) error: %v", ids, err) + } + if decoded != text { + t.Errorf("round-trip: %q -> %v -> %q", text, ids, decoded) + } + } +} + +func TestWordPieceTokenizer_VocabSize(t *testing.T) { + tok := testWordPieceTokenizer() + if got := tok.VocabSize(); got != 19 { + t.Errorf("VocabSize() = %d, want 19", got) + } +} + +func TestWordPieceTokenizer_GetToken(t *testing.T) { + tok := testWordPieceTokenizer() + tok1, ok := tok.GetToken(5) + if !ok || tok1 != "hello" { + t.Errorf("GetToken(5) = (%q, %v), want (\"hello\", true)", tok1, ok) + } + _, ok = tok.GetToken(999) + if ok { + t.Error("GetToken(999) should return false") + } +} + +func TestWordPieceTokenizer_GetID(t *testing.T) { + tok := testWordPieceTokenizer() + id, ok := tok.GetID("hello") + if !ok || id != 5 { + t.Errorf("GetID(\"hello\") = (%d, %v), want (5, true)", id, ok) + } + _, ok = tok.GetID("nonexistent") + if ok { + t.Error("GetID(\"nonexistent\") should return false") + } +} + +func TestWordPieceTokenizer_SpecialTokens(t *testing.T) { + tok := testWordPieceTokenizer() + sp := tok.SpecialTokens() + if sp.BOS != 2 { + t.Errorf("BOS = %d, want 
2", sp.BOS) + } + if sp.EOS != 3 { + t.Errorf("EOS = %d, want 3", sp.EOS) + } + if sp.PAD != 0 { + t.Errorf("PAD = %d, want 0", sp.PAD) + } + if sp.UNK != 1 { + t.Errorf("UNK = %d, want 1", sp.UNK) + } +} + +func TestWordPieceTokenizer_EncodeForBERT_Single(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello world", "", 0) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, world=6, [SEP]=3 + wantIDs := []int{2, 5, 6, 3} + wantMask := []int{1, 1, 1, 1} + wantTypes := []int{0, 0, 0, 0} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_EncodeForBERT_Pair(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello", "world", 0) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, [SEP]=3, world=6, [SEP]=3 + wantIDs := []int{2, 5, 3, 6, 3} + wantMask := []int{1, 1, 1, 1, 1} + wantTypes := []int{0, 0, 0, 1, 1} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_EncodeForBERT_Padding(t *testing.T) { + tok := testWordPieceTokenizer() + + enc, err := tok.EncodeForBERT("hello", "", 8) + if err != nil { + t.Fatalf("EncodeForBERT error: %v", err) + } + + // [CLS]=2, hello=5, [SEP]=3, [PAD]=0 x5 + wantIDs := []int{2, 5, 3, 0, 0, 0, 0, 0} + wantMask := []int{1, 1, 1, 0, 0, 0, 0, 0} + wantTypes := []int{0, 0, 0, 0, 0, 0, 0, 0} + + assertIntSlice(t, "InputIDs", enc.InputIDs, wantIDs) + assertIntSlice(t, "AttentionMask", enc.AttentionMask, wantMask) + assertIntSlice(t, "TokenTypeIDs", enc.TokenTypeIDs, wantTypes) +} + +func TestWordPieceTokenizer_PreTokenize(t *testing.T) { + tests := []struct { + input 
string + want []string + }{ + {"hello world", []string{"hello", "world"}}, + {"hello,world", []string{"hello", ",", "world"}}, + {"hello. world!", []string{"hello", ".", "world", "!"}}, + {" spaces ", []string{"spaces"}}, + {"", nil}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := preTokenize(tt.input) + if len(got) != len(tt.want) { + t.Fatalf("preTokenize(%q) = %v, want %v", tt.input, got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("preTokenize(%q)[%d] = %q, want %q", tt.input, i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestLoad_WordPiece(t *testing.T) { + fixture := `{ + "model": { + "type": "WordPiece", + "vocab": { + "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4, + "hello": 5, "world": 6, "un": 7, "##aff": 8, "##able": 9 + } + }, + "added_tokens": [ + {"id": 0, "content": "[PAD]", "special": true}, + {"id": 1, "content": "[UNK]", "special": true}, + {"id": 2, "content": "[CLS]", "special": true}, + {"id": 3, "content": "[SEP]", "special": true}, + {"id": 4, "content": "[MASK]", "special": true} + ], + "normalizer": { + "type": "Lowercase" + } +}` + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(fixture), 0o600); err != nil { + t.Fatal(err) + } + + tok, err := Load(path) + if err != nil { + t.Fatalf("Load error: %v", err) + } + + wp, ok := tok.(*WordPieceTokenizer) + if !ok { + t.Fatalf("expected *WordPieceTokenizer, got %T", tok) + } + + // Normalizer should lowercase. + ids, err := wp.Encode("HELLO") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + if len(ids) != 1 || ids[0] != 5 { + t.Errorf("Encode(\"HELLO\") = %v, want [5]", ids) + } + + // Subword splitting. + ids, err = wp.Encode("unaffable") + if err != nil { + t.Fatalf("Encode error: %v", err) + } + wantIDs := []int{7, 8, 9} + assertIntSlice(t, "subword", ids, wantIDs) + + // Special tokens. 
+ sp := tok.SpecialTokens() + if sp.BOS != 2 || sp.EOS != 3 || sp.PAD != 0 || sp.UNK != 1 { + t.Errorf("SpecialTokens = %+v", sp) + } +} + +func TestLoad_BPE(t *testing.T) { + // Load should still work for BPE models. + tok, err := Load(filepath.Join("testdata", "tokenizer.json")) + if err != nil { + t.Fatalf("Load error: %v", err) + } + if _, ok := tok.(*BPETokenizer); !ok { + t.Fatalf("expected *BPETokenizer, got %T", tok) + } +} + +func TestLoad_UnsupportedType(t *testing.T) { + fixture := `{ + "model": {"type": "Unigram", "vocab": {}}, + "added_tokens": [] +}` + dir := t.TempDir() + path := filepath.Join(dir, "tokenizer.json") + if err := os.WriteFile(path, []byte(fixture), 0o600); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for Unigram model type") + } +} + +func assertIntSlice(t *testing.T, name string, got, want []int) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("%s: len=%d, want len=%d; got %v, want %v", name, len(got), len(want), got, want) + } + for i := range got { + if got[i] != want[i] { + t.Errorf("%s[%d] = %d, want %d", name, i, got[i], want[i]) + } + } +}