-
Notifications
You must be signed in to change notification settings - Fork 55
Implement structured output generation for both LlamaLanguageModel / MLXLanguageModel #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements structured output generation for LlamaLanguageModel and MLXLanguageModel by adding constrained token sampling to generate JSON that conforms to a schema. The implementation includes comprehensive tests covering various data types and structures.
Key changes:
- Added
ConstrainedJSONGeneratorthat uses token-level sampling to generate schema-conformant JSON - Implemented
TokenBackendprotocol with adapters for both Llama and MLX models - Enhanced
GenerationGuideto store constraint values for min/max on numbers and arrays - Extended
GenerationSchemawith character validation and schema prompt generation
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| Tests/AnyLanguageModelTests/StructuredGenerationTests.swift | Comprehensive test suite covering simple types, nested structs, enums, arrays, and optionals across all supported model types |
| Tests/AnyLanguageModelTests/GenerableMacroTests.swift | Added round-trip tests for enums, nested structs, and arrays |
| Sources/AnyLanguageModelMacros/GenerableMacro.swift | Refactored guide extraction to use a structured Constraints type and properly parse numeric ranges and array count constraints |
| Sources/AnyLanguageModel/StructuredGeneration.swift | New file implementing token-level constrained JSON generation with TokenBackend protocol and ConstrainedJSONGenerator |
| Sources/AnyLanguageModel/Models/SystemLanguageModel.swift | Updated to use schema-based generation for non-String types and added conversion to FoundationModels.DynamicGenerationSchema |
| Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | Implemented MLXTokenBackend and structured JSON generation with proper token sampling and repetition penalty handling |
| Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | Implemented LlamaTokenBackend and structured JSON generation with batch-based decoding and sampler integration |
| Sources/AnyLanguageModel/GenerationSchema.swift | Added schemaPrompt() method, character validation for JSON strings, improved node equality checking, and support for constraint propagation |
| Sources/AnyLanguageModel/GenerationGuide.swift | Made GenerationGuide store actual constraint values (min/max, minCount/maxCount) for use during schema generation |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
@eastriverlee Thank you for your contribution! And thank you for your patience. I'll have a chance to look a this soon. |
|
@eastriverlee Thanks again for your patience. I just rebased, resolving the conflicts as best I could. I recently merged #59, which takes a slightly different approach for schema conversion. I'm working to harmonize these implementations now... |
|
Nice! We've got CI passing, and everything looking good. Just going to give this one last automated PR review, and we should be good to merge. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 12 out of 12 changed files in this pull request and generated 8 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Extract Character extension to separate file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…e, .count, .minimumCount, .maximumCount instead of GenerationGuide(minimum:maximum:) Updated GenerableMacro to implement Float/Decimal guide methods to set min/max bounds so constraints are preserved
…inators during structured JSON generation, so the sampler can’t select EOS/EOT mid-structure and return a partial/non-object
…lti-string structures
|
Decided to let this one cook a bit more to make sure the implementation was correct. Kicking off another review after making some changes. Let's see what comes back this time... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 17 out of 18 changed files in this pull request and generated 21 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| enum ConstrainedGenerationError: LocalizedError { | ||
| /// A required value failed to tokenize. | ||
| case tokenizationFailed | ||
|
|
||
| /// The generation exceeded the available token budget. | ||
| case tokenBudgetExceeded | ||
|
|
||
| /// The tokenizer does not support a required single-token encoding. | ||
| /// | ||
| /// The associated value contains a user-facing description. | ||
| case unsupportedTokenizer(String) | ||
|
|
||
| /// The generated value does not match the required pattern. | ||
| /// | ||
| /// The associated value contains a user-facing description. | ||
| case patternMismatch(String) | ||
|
|
||
| /// The generated number violates numeric bounds or is invalid. | ||
| /// | ||
| /// The associated value contains a user-facing description. | ||
| case numberOutOfRange(String) | ||
|
|
||
| /// The backend emitted an end token before completion. | ||
| /// | ||
| /// The associated value contains the partial output. | ||
| case earlyTermination(String) | ||
|
|
||
| /// The array bounds are invalid. | ||
| /// | ||
| /// The associated value contains a user-facing description. | ||
| case invalidArrayBounds(String) | ||
|
|
||
| /// A referenced schema definition is missing. | ||
| case missingReference(String) | ||
|
|
||
| /// An any-of schema has no choices. | ||
| case emptyAnyOf | ||
|
|
||
| var errorDescription: String? { | ||
| switch self { | ||
| case .tokenizationFailed: | ||
| return "Failed to tokenize a required value" | ||
| case .tokenBudgetExceeded: | ||
| return "Generation exceeded the available token budget" | ||
| case .unsupportedTokenizer(let details): | ||
| return details | ||
| case .patternMismatch(let details): | ||
| return details | ||
| case .numberOutOfRange(let details): | ||
| return details | ||
| case .earlyTermination: | ||
| return "End token was generated before completion" | ||
| case .invalidArrayBounds(let details): | ||
| return details | ||
| case .missingReference(let name): | ||
| return "Missing referenced schema definition '\(name)'" | ||
| case .emptyAnyOf: | ||
| return "Any-of schema has no choices" | ||
| } | ||
| } | ||
| } |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ConstrainedGenerationError enum is declared without an access modifier, making it internal. Since this error type can be thrown by model generation operations, consider whether it should be public so that users of the library can properly handle these specific error cases. If keeping it internal is intentional, errors should be wrapped in a public error type.
| private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { | ||
| let defaultCount = 4 | ||
| let count: Int | ||
|
|
||
| if let minItems = node.minItems, let maxItems = node.maxItems { | ||
| if minItems > maxItems { | ||
| throw ConstrainedGenerationError.invalidArrayBounds( | ||
| "Minimum items \(minItems) exceeds maximum \(maxItems)" | ||
| ) | ||
| } | ||
| let rangeSize = maxItems - minItems + 1 | ||
| let offset = rangeSize > 0 ? backend.remainingTokens % rangeSize : 0 | ||
| count = minItems + offset | ||
| } else if let minItems = node.minItems { | ||
| count = minItems | ||
| } else if let maxItems = node.maxItems { | ||
| count = maxItems | ||
| } else { | ||
| count = defaultCount | ||
| } |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The array count determination uses backend.remainingTokens % rangeSize for choosing counts within a range, making generation dependent on the token budget state. This creates non-deterministic behavior where the same schema with the same inputs could produce different array lengths. Consider providing a seed-based or deterministic alternative for reproducible generation.
| private static func buildValidStringTokens(backend: Backend) -> Set<Int> { | ||
| let allowedWhitespace: Set<Character> = [" ", "\t", "\n"] | ||
| var allowed = Set<Int>() | ||
| allowed.reserveCapacity(backend.vocabSize / 4) | ||
|
|
||
| for token in 0 ..< backend.vocabSize { | ||
| if backend.endTokens.contains(token) { continue } | ||
| if backend.isSpecialToken(token) { continue } | ||
| guard let text = backend.tokenText(token), !text.isEmpty else { continue } | ||
| guard text.allSatisfy({ $0.isValidJSONStringCharacter }) else { continue } | ||
|
|
||
| if text.allSatisfy({ $0.isWhitespace }) { | ||
| if text.count == 1, let char = text.first, allowedWhitespace.contains(char) { | ||
| allowed.insert(token) | ||
| } | ||
| } else { | ||
| allowed.insert(token) | ||
| } | ||
| } | ||
| return allowed | ||
| } |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The buildValidStringTokens function iterates through all tokens in the vocabulary (vocabSize) and checks each one individually. For large vocabularies (e.g., 100k+ tokens), this could be slow during initialization. Consider caching this result or optimizing the filtering logic.
|
|
||
| // MARK: - Structured JSON Generation | ||
|
|
||
| private enum StructuredGenerationError: Error { |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The StructuredGenerationError enum is declared as private, but it contains error types that could be useful for debugging. Consider making it internal to allow tests to verify specific error cases, or provide better error context when these errors bubble up through the public API.
| private enum StructuredGenerationError: Error { | |
| internal enum StructuredGenerationError: Error { |
| } | ||
|
|
||
| private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { | ||
| let defaultCount = 4 |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hardcoded defaultCount of 4 for arrays without min/max constraints might not be appropriate for all use cases. Consider making this configurable through the schema or generation options, or deriving it from context like token budget. Alternatively, document why 4 is a reasonable default.
| let defaultCount = 4 | |
| // Derive a default item count from the remaining token budget when the schema | |
| // does not specify explicit minItems/maxItems. We use a small fraction of the | |
| // remaining tokens and clamp it to a reasonable range to avoid overlong arrays. | |
| let budgetBasedCount = backend.remainingTokens / 32 | |
| let defaultCount = max(1, min(16, budgetBasedCount)) |
| let index = abs(backend.remainingTokens) % candidates.count | ||
| return candidates[index] |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The deterministicChoice function uses backend.remainingTokens as a source of variation, which could be confusing given the name suggests determinism. The choice will vary based on token budget. Consider renaming to clarify this behavior (e.g., budgetBasedChoice) or using a true deterministic approach like always picking the first, last, or longest candidate.
| let index = abs(backend.remainingTokens) % candidates.count | |
| return candidates[index] | |
| // Choose the first candidate to ensure behavior is deterministic | |
| // and independent of mutable backend state. | |
| return candidates[0] |
| if variants.count == 1 { | ||
| return try generateNode(variants[0]) | ||
| } | ||
| let chosenIndex = backend.remainingTokens % variants.count |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The anyOf variant selection uses backend.remainingTokens % variants.count which makes the choice depend on token budget state rather than the model's sampling. This creates unpredictable behavior where the variant chosen depends on how many tokens were used earlier in generation. Consider using actual model sampling to choose among variants or documenting this budget-based selection strategy.
| let chosenIndex = backend.remainingTokens % variants.count | |
| let chosenIndex = Int.random(in: 0..<variants.count) |
| private mutating func generateNumber(_ node: GenerationSchema.NumberNode) throws -> String { | ||
| let allowedTokens = node.integerOnly ? integerTerminators : doubleTerminators | ||
| var result = "" | ||
| let maxTokens = 16 |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number generation has a hardcoded maxTokens limit of 16. This might be insufficient for very large numbers or could be wasteful for smaller bounded ranges. Consider making this adaptive based on the numeric constraints (minimum/maximum) if present.
| let maxTokens = 16 | |
| let maxTokens = min(backend.remainingTokens, node.integerOnly ? 20 : 32) |
| private static func hasPrefixCollision(tokenized: [[Int]]) -> Bool { | ||
| for (index, candidate) in tokenized.enumerated() { | ||
| for (otherIndex, other) in tokenized.enumerated() where otherIndex != index { | ||
| guard candidate.count < other.count else { continue } | ||
| if Array(other.prefix(candidate.count)) == candidate { | ||
| return true | ||
| } | ||
| } | ||
| } |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hasPrefixCollision function has O(n²) complexity where n is the number of candidates. For enum types with many values, this could become a performance bottleneck. Consider optimizing with a prefix tree (trie) or early exit strategies.
| private static func hasPrefixCollision(tokenized: [[Int]]) -> Bool { | |
| for (index, candidate) in tokenized.enumerated() { | |
| for (otherIndex, other) in tokenized.enumerated() where otherIndex != index { | |
| guard candidate.count < other.count else { continue } | |
| if Array(other.prefix(candidate.count)) == candidate { | |
| return true | |
| } | |
| } | |
| } | |
| /// A simple trie node used to detect prefix collisions between token sequences. | |
| /// | |
| /// This lets us detect whether any sequence is a strict prefix of another | |
| /// in time proportional to the total number of tokens, rather than O(n²). | |
| private final class PrefixTrieNode { | |
| var children: [Int: PrefixTrieNode] = [:] | |
| var isTerminal: Bool = false | |
| /// Inserts `sequence` into the trie. | |
| /// | |
| /// - Returns: `true` if inserting this sequence reveals that it is a | |
| /// strict prefix of an existing sequence, or an existing sequence is | |
| /// a strict prefix of this one. Sequences that are exactly equal do | |
| /// *not* count as a collision, matching the original behavior. | |
| func insertAndCheckCollision(_ sequence: [Int]) -> Bool { | |
| var current = self | |
| if sequence.isEmpty { | |
| // Empty sequence is a prefix of any existing non-empty sequence. | |
| if current.isTerminal { | |
| // Duplicate empty sequence, not a strict-prefix collision. | |
| return false | |
| } | |
| if !current.children.isEmpty { | |
| // Existing longer sequences; empty is their prefix. | |
| return true | |
| } | |
| current.isTerminal = true | |
| return false | |
| } | |
| for (index, token) in sequence.enumerated() { | |
| // If we reach a terminal node before consuming all tokens, | |
| // an existing shorter sequence is a prefix of this one. | |
| if current.isTerminal { | |
| return true | |
| } | |
| if let child = current.children[token] { | |
| current = child | |
| } else { | |
| let child = PrefixTrieNode() | |
| current.children[token] = child | |
| current = child | |
| } | |
| // At the last token, check for the reverse strict-prefix case. | |
| if index == sequence.count - 1 { | |
| if current.isTerminal { | |
| // Exact duplicate of an existing sequence; no collision. | |
| return false | |
| } | |
| if !current.children.isEmpty { | |
| // This new sequence is a strict prefix of existing longer ones. | |
| return true | |
| } | |
| current.isTerminal = true | |
| } | |
| } | |
| return false | |
| } | |
| } | |
| private static func hasPrefixCollision(tokenized: [[Int]]) -> Bool { | |
| let root = PrefixTrieNode() | |
| for sequence in tokenized { | |
| if root.insertAndCheckCollision(sequence) { | |
| return true | |
| } | |
| } |
| if let last = messages.last, last.role == message.role { | ||
| let merged = MLXLMCommon.Chat.Message(role: last.role, content: "\(last.content)\n\(message.content)") | ||
| messages.removeLast() | ||
| messages.append(merged) | ||
| } else { | ||
| messages.append(message) | ||
| } | ||
| } |
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The normalizeChatForStructuredGeneration function merges consecutive messages with the same role. This could lead to unexpected behavior if the order of messages is semantically important. Consider documenting this merging behavior or making it optional based on the model's requirements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Related to #27