Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions Package.resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ let package = Package(
name: "mlx-server",
platforms: [.macOS(.v14)],
dependencies: [
// Apple MLX Swift — core inference engine
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.3")),
// Apple's LLM library built on MLX Swift (Qwen, Llama, Mistral, Gemma etc.)
.package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.0.0"),
// Apple MLX Swift — core inference engine (Apple-maintained, tagged releases)
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.6")),
// Apple's LLM library built on MLX Swift (SharpAI fork)
// Pinned to main branch for Qwen3.5 support (PRs #97, #120, #129, #133, #135 — not yet in a release tag)
.package(url: "https://github.com/SharpAI/mlx-swift-lm", branch: "main"),
// HuggingFace tokenizers + model download
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.1.0")),
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "1.2.0")),
// Lightweight HTTP server (Apple-backed Swift server project)
.package(url: "https://github.com/hummingbird-project/hummingbird", from: "2.0.0"),
// Async argument parser (for CLI flags: --model, --port)
Expand Down
143 changes: 133 additions & 10 deletions Sources/mlx-server/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ struct MLXServer: AsyncParsableCommand {
@Option(name: .long, help: "Number of parallel request slots")
var parallel: Int = 1

@Flag(name: .long, help: "Enable thinking/reasoning mode (Qwen3.5 etc). Default: disabled")
var thinking: Bool = false

mutating func run() async throws {
print("[mlx-server] Loading model: \(model)")
let modelId = model
Expand All @@ -72,6 +75,7 @@ struct MLXServer: AsyncParsableCommand {
let defaultTemp = self.temp
let defaultTopP = self.topP
let defaultRepeatPenalty = self.repeatPenalty
let thinkingEnabled = self.thinking
let parallelSlots = self.parallel

// ── Concurrency limiter ──
Expand Down Expand Up @@ -138,27 +142,48 @@ struct MLXServer: AsyncParsableCommand {
}
}

// Convert OpenAI tools format → [String: any Sendable] for UserInput
let toolSpecs: [[String: any Sendable]]? = chatReq.tools?.map { tool in
var spec: [String: any Sendable] = ["type": tool.type]
var fn: [String: any Sendable] = ["name": tool.function.name]
if let desc = tool.function.description { fn["description"] = desc }
if let params = tool.function.parameters {
fn["parameters"] = params.mapValues { $0.value }
}
spec["function"] = fn
return spec
}

// ── Acquire slot (concurrency limiter) ──
await semaphore.wait()

let userInput = UserInput(chat: chatMessages)
// Pass enable_thinking to the Jinja chat template via additionalContext
// (mirrors llama-server's --chat-template-kwargs '{"enable_thinking":false}')
let templateContext: [String: any Sendable]? = thinkingEnabled ? nil : ["enable_thinking": false]
let userInput = UserInput(chat: chatMessages, tools: toolSpecs, additionalContext: templateContext)
let lmInput = try await container.prepare(input: userInput)
let stream = try await container.generate(input: lmInput, parameters: params)

if isStream {
// SSE streaming
let (sseStream, cont) = AsyncStream<String>.makeStream()
Task {
var hasToolCalls = false
var toolCallIndex = 0
for await generation in stream {
switch generation {
case .chunk(let text):
cont.yield(sseChunk(modelId: modelId, delta: text, finishReason: nil))
case .toolCall(let tc):
hasToolCalls = true
let argsJson = serializeToolCallArgs(tc.function.arguments)
cont.yield(sseToolCallChunk(modelId: modelId, index: toolCallIndex, name: tc.function.name, arguments: argsJson))
toolCallIndex += 1
case .info:
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: "stop"))
let reason = hasToolCalls ? "tool_calls" : "stop"
cont.yield(sseChunk(modelId: modelId, delta: "", finishReason: reason))
cont.yield("data: [DONE]\n\n")
cont.finish()
case .toolCall:
break
}
}
cont.finish()
Expand All @@ -170,15 +195,25 @@ struct MLXServer: AsyncParsableCommand {
body: .init(asyncSequence: sseStream.map { ByteBuffer(string: $0) })
)
} else {
// Non-streaming: collect all chunks
// Non-streaming: collect all chunks and tool calls
var fullText = ""
var completionTokenCount = 0
var collectedToolCalls: [ToolCallResponse] = []
var tcIndex = 0
for await generation in stream {
switch generation {
case .chunk(let text):
fullText += text
completionTokenCount += 1
case .info, .toolCall:
case .toolCall(let tc):
let argsJson = serializeToolCallArgs(tc.function.arguments)
collectedToolCalls.append(ToolCallResponse(
id: "call_\(UUID().uuidString.prefix(8))",
type: "function",
function: ToolCallFunction(name: tc.function.name, arguments: argsJson)
))
tcIndex += 1
case .info:
break
}
}
Expand All @@ -189,15 +224,20 @@ struct MLXServer: AsyncParsableCommand {
let estimatedPromptTokens = max(1, promptText.count / 4)
let totalTokens = estimatedPromptTokens + completionTokenCount

let hasToolCalls = !collectedToolCalls.isEmpty
let resp = ChatCompletionResponse(
id: "chatcmpl-\(UUID().uuidString)",
model: modelId,
created: Int(Date().timeIntervalSince1970),
choices: [
Choice(
index: 0,
message: AssistantMessage(role: "assistant", content: fullText),
finishReason: "stop"
message: AssistantMessage(
role: "assistant",
content: fullText.isEmpty && hasToolCalls ? nil : fullText,
toolCalls: hasToolCalls ? collectedToolCalls : nil
),
finishReason: hasToolCalls ? "tool_calls" : "stop"
)
],
usage: TokenUsage(promptTokens: estimatedPromptTokens, completionTokens: completionTokenCount, totalTokens: totalTokens)
Expand Down Expand Up @@ -312,23 +352,68 @@ func sseChunk(modelId: String, delta: String, finishReason: String?) -> String {
return "data: \(String(data: data, encoding: .utf8)!)\n\n"
}

/// Build one OpenAI-compatible SSE event announcing a streamed tool call.
///
/// - Parameters:
///   - modelId: Model identifier echoed back in the chunk.
///   - index: Position of this tool call within the delta's `tool_calls` array.
///   - name: Function name the model wants invoked.
///   - arguments: Function arguments, already serialized as a JSON string.
/// - Returns: A full SSE event of the form `data: {json}\n\n`.
func sseToolCallChunk(modelId: String, index: Int, name: String, arguments: String) -> String {
    // Assemble the payload bottom-up so each nesting level is named.
    let function: [String: Any] = [
        "name": name,
        "arguments": arguments,
    ]
    let toolCall: [String: Any] = [
        "index": index,
        // Short random id in OpenAI's "call_xxxxxxxx" shape.
        "id": "call_\(UUID().uuidString.prefix(8))",
        "type": "function",
        "function": function,
    ]
    let delta: [String: Any] = [
        "role": "assistant",
        "tool_calls": [toolCall],
    ]
    let choice: [String: Any] = [
        "index": 0,
        "delta": delta,
    ]
    let payload: [String: Any] = [
        "id": "chatcmpl-\(UUID().uuidString)",
        "object": "chat.completion.chunk",
        "created": Int(Date().timeIntervalSince1970),
        "model": modelId,
        "choices": [choice],
    ]
    // Safe force-try: every value above is a JSON-representable plist type.
    let encoded = try! JSONSerialization.data(withJSONObject: payload)
    return "data: \(String(data: encoded, encoding: .utf8)!)\n\n"
}

/// Serialize a tool call's arguments (`[String: JSONValue]`) into a JSON string.
/// Falls back to an empty JSON object ("{}") if serialization fails.
func serializeToolCallArgs(_ args: [String: JSONValue]) -> String {
    // Unwrap JSONValue into plain Foundation values before serializing.
    let plainValues = args.mapValues(\.anyValue)
    if let encoded = try? JSONSerialization.data(withJSONObject: plainValues),
       let jsonString = String(data: encoded, encoding: .utf8) {
        return jsonString
    }
    return "{}"
}

// ── OpenAI-compatible types ───────────────────────────────────────────────────

struct ChatCompletionRequest: Decodable {
struct Message: Decodable {
let role: String
let content: String
}
struct ToolDef: Decodable {
let type: String
let function: ToolFuncDef
}
struct ToolFuncDef: Decodable {
let name: String
let description: String?
let parameters: [String: AnyCodable]?
}
let model: String?
let messages: [Message]
let stream: Bool?
let maxTokens: Int?
let temperature: Double?
let topP: Double?
let repetitionPenalty: Double?
let tools: [ToolDef]?

enum CodingKeys: String, CodingKey {
case model, messages, stream, temperature
case model, messages, stream, temperature, tools
case maxTokens = "max_tokens"
case topP = "top_p"
case repetitionPenalty = "repetition_penalty"
Expand Down Expand Up @@ -357,7 +442,45 @@ struct Choice: Encodable {

/// The assistant message inside an OpenAI-compatible response choice.
///
/// NOTE(review): the pasted diff left both the old `let content: String` and
/// the new `let content: String?` in place — a duplicate declaration that
/// cannot compile. Resolved to the post-diff optional version.
struct AssistantMessage: Encodable {
    let role: String
    /// Generated text; nil when the model produced only tool calls
    /// (mirrors OpenAI's null `content` for tool-call responses).
    let content: String?
    /// Tool calls requested by the model; encoded under the wire key "tool_calls".
    let toolCalls: [ToolCallResponse]?

    enum CodingKeys: String, CodingKey {
        case role, content
        case toolCalls = "tool_calls"
    }
}

/// A single tool call emitted by the model, in OpenAI chat-completions wire format
/// (one element of the assistant message's `tool_calls` array).
struct ToolCallResponse: Encodable {
    /// Call identifier; generated server-side as "call_" plus a short UUID prefix.
    let id: String
    /// Call kind; this server always emits "function".
    let type: String
    /// The function name and its JSON-string-encoded arguments.
    let function: ToolCallFunction
}

/// The function portion of a tool call: name plus arguments as a JSON string
/// (OpenAI's wire format keeps arguments string-encoded, not as a nested object).
struct ToolCallFunction: Encodable {
    let name: String
    /// Arguments serialized to a JSON string, e.g. "{\"city\":\"Paris\"}".
    let arguments: String
}

/// AnyCodable: decodes arbitrary JSON for tool-parameters pass-through.
///
/// `@unchecked Sendable`: the stored property is `Any`, which the compiler
/// cannot verify, but every value this type ever holds is an immutable JSON
/// scalar/collection (NSNull, Bool, Int, Double, String, or arrays/dictionaries
/// thereof), so sharing across concurrency domains is safe.
struct AnyCodable: Decodable, @unchecked Sendable {
    /// The decoded JSON value: NSNull, Bool, Int, Double, String, [Any], or [String: Any].
    let value: Any

    init(from decoder: Decoder) throws {
        let c = try decoder.singleValueContainer()
        // Order matters: Bool before Int (JSON `true` must not collapse into a
        // number), Int before Double (whole numbers stay integral).
        if c.decodeNil() { value = NSNull() }
        else if let b = try? c.decode(Bool.self) { value = b }
        else if let i = try? c.decode(Int.self) { value = i }
        else if let d = try? c.decode(Double.self) { value = d }
        else if let s = try? c.decode(String.self) { value = s }
        else if let a = try? c.decode([AnyCodable].self) { value = a.map { $0.value } }
        else if let d = try? c.decode([String: AnyCodable].self) { value = d.mapValues { $0.value } }
        else { value = NSNull() }
    }

    /// Convert back to `[String: any Sendable]` for ToolSpec usage.
    ///
    /// The previous `$0.value as! any Sendable` is invalid Swift: `Sendable` is a
    /// marker protocol and cannot be the target of a conditional/forced cast.
    /// Convert each value explicitly instead.
    static func toSendable(_ dict: [String: AnyCodable]?) -> [String: any Sendable]? {
        guard let dict else { return nil }
        return dict.mapValues { sendableValue($0.value) }
    }

    /// Recursively rewrap a decoded JSON value as `any Sendable`.
    private static func sendableValue(_ v: Any) -> any Sendable {
        switch v {
        case let b as Bool: return b
        case let i as Int: return i
        case let d as Double: return d
        case let s as String: return s
        case let a as [Any]: return a.map(sendableValue)
        case let m as [String: Any]: return m.mapValues(sendableValue)
        // NSNull is the only remaining case `init(from:)` can produce;
        // Foundation declares NSNull as Sendable (immutable singleton).
        default: return NSNull()
        }
    }
}

struct TokenUsage: Encodable {
Expand Down
Loading