Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions mlx/io/safetensors.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright © 2023 Apple Inc.
//

#include <json.hpp>
#include <memory>
#include <sstream>
#include <stack>

#include "mlx/backend/cuda/cuda.h"
Expand Down Expand Up @@ -97,8 +98,9 @@ Dtype dtype_from_safetensor_str(std::string_view str) {
} else if (str == ST_F8_E4M3) {
return uint8;
} else {
throw std::runtime_error(
"[safetensor] unsupported dtype " + std::string(str));
std::ostringstream msg;
msg << "[safetensor] unsupported dtype" << str;
throw std::runtime_error(msg.str());
}
}

Expand All @@ -109,8 +111,9 @@ SafetensorsLoad load_safetensors(
////////////////////////////////////////////////////////
// Open and check file
if (!in_stream->good() || !in_stream->is_open()) {
throw std::runtime_error(
"[load_safetensors] Failed to open " + in_stream->label());
std::ostringstream msg;
msg << "[load_safetensors] Failed to open " << in_stream->label();
throw std::runtime_error(msg.str());
}

auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
Expand All @@ -120,17 +123,20 @@ SafetensorsLoad load_safetensors(
constexpr uint64_t kMaxJsonHeaderLength = 100000000;
in_stream->read(reinterpret_cast<char*>(&jsonHeaderLength), 8);
if (jsonHeaderLength <= 0 || jsonHeaderLength >= kMaxJsonHeaderLength) {
throw std::runtime_error(
"[load_safetensors] Invalid json header length " + in_stream->label());
std::ostringstream msg;
msg << "[load_safetensors] Invalid json header length "
<< in_stream->label();
throw std::runtime_error(msg.str());
}
// Load the json metadata
auto rawJson = std::make_unique<char[]>(jsonHeaderLength);
in_stream->read(rawJson.get(), jsonHeaderLength);
auto metadata = json::parse(rawJson.get(), rawJson.get() + jsonHeaderLength);
// Should always be an object on the top-level
if (!metadata.is_object()) {
throw std::runtime_error(
"[load_safetensors] Invalid json metadata " + in_stream->label());
std::ostringstream msg;
msg << "[load_safetensors] Invalid json metadata " << in_stream->label();
throw std::runtime_error(msg.str());
}
size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata
Expand All @@ -147,6 +153,28 @@ SafetensorsLoad load_safetensors(
const Shape& shape = item.value().at("shape");
const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
Dtype type = dtype_from_safetensor_str(dtype);
if (data_offsets.size() != 2) {
std::ostringstream msg;
msg << "[load_safetensors] Tensor '" << item.key()
<< "' data_offsets must have exactly 2 entries but has "
<< data_offsets.size();
throw std::runtime_error(msg.str());
}
{
size_t expected_nbytes = type.size();
for (auto dim : shape) {
expected_nbytes *= static_cast<size_t>(dim);
}
if (data_offsets[1] < data_offsets[0] ||
data_offsets[1] - data_offsets[0] != expected_nbytes) {
std::ostringstream msg;
msg << "[load_safetensors] Tensor '" << item.key()
<< "' invalid data offsets (" << data_offsets[0] << ", "
<< data_offsets[1] << "). Expecting " << expected_nbytes
<< " bytes.";
throw std::runtime_error(msg.str());
}
}
res.insert(
{item.key(),
array(
Expand All @@ -170,8 +198,9 @@ void save_safetensors(
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error(
"[save_safetensors] Failed to open " + out_stream->label());
std::ostringstream msg;
msg << "[save_safetensors] Failed to open " << out_stream->label();
throw std::runtime_error(msg.str());
}

////////////////////////////////////////////////////////
Expand All @@ -196,8 +225,10 @@ void save_safetensors(
size_t offset = 0;
for (auto& [key, arr] : a) {
if (arr.nbytes() == 0) {
throw std::invalid_argument(
"[save_safetensors] cannot serialize an empty array key: " + key);
std::ostringstream msg;
msg << "[save_safetensors] Cannot serialize an empty array ('" << key
<< "')";
throw std::invalid_argument(msg.str());
}

json child;
Expand Down
62 changes: 62 additions & 0 deletions tests/load_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.

#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -40,6 +41,67 @@ TEST_CASE("test save_safetensors") {
CHECK(array_equal(test2, ones({2, 2})).item<bool>());
}

TEST_CASE("test safetensors rejects mismatched data_offsets") {
  // The header below declares a 1000x1000 float32 tensor (4,000,000 bytes)
  // while its data_offsets span only 4 bytes. Loading such a file is
  // expected to raise std::runtime_error.
  std::string file_path = get_temp_file("test_bad_offsets.safetensors");

  const std::string json_header =
      R"({"t":{"dtype":"F32","shape":[1000,1000],"data_offsets":[0,4]}})";
  const uint64_t json_header_size = json_header.size();
  const float payload = 1.0f;

  {
    // Inner scope: the stream is destroyed (and the file closed) before the
    // file is read back.
    std::ofstream out(file_path, std::ios::binary);
    // Header length first, then the JSON header itself.
    out.write(
        reinterpret_cast<const char*>(&json_header_size),
        sizeof(json_header_size));
    out.write(json_header.data(), json_header_size);
    // Only a single float (4 bytes) of tensor data, matching the [0,4]
    // offsets rather than the declared shape.
    out.write(reinterpret_cast<const char*>(&payload), sizeof(payload));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test safetensors rejects bad data_offsets count") {
  // The header's data_offsets array carries three entries where exactly two
  // are required. Loading such a file is expected to raise
  // std::runtime_error.
  std::string file_path = get_temp_file("test_bad_offsets_count.safetensors");

  const std::string json_header =
      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[0,4,8]}})";
  const uint64_t json_header_size = json_header.size();
  const float payload = 1.0f;

  {
    // Inner scope: the stream is destroyed (and the file closed) before the
    // file is read back.
    std::ofstream out(file_path, std::ios::binary);
    // Header length first, then the JSON header, then one float of data.
    out.write(
        reinterpret_cast<const char*>(&json_header_size),
        sizeof(json_header_size));
    out.write(json_header.data(), json_header_size);
    out.write(reinterpret_cast<const char*>(&payload), sizeof(payload));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test safetensors rejects inverted data_offsets") {
  // The header's data_offsets are reversed: the begin offset (4) is larger
  // than the end offset (0). Loading such a file is expected to raise
  // std::runtime_error.
  std::string file_path =
      get_temp_file("test_bad_offsets_inverted.safetensors");

  const std::string json_header =
      R"({"t":{"dtype":"F32","shape":[1],"data_offsets":[4,0]}})";
  const uint64_t json_header_size = json_header.size();
  const float payload = 1.0f;

  {
    // Inner scope: the stream is destroyed (and the file closed) before the
    // file is read back.
    std::ofstream out(file_path, std::ios::binary);
    // Header length first, then the JSON header, then one float of data.
    out.write(
        reinterpret_cast<const char*>(&json_header_size),
        sizeof(json_header_size));
    out.write(json_header.data(), json_header_size);
    out.write(reinterpret_cast<const char*>(&payload), sizeof(payload));
  }

  CHECK_THROWS_AS(load_safetensors(file_path), std::runtime_error);
}

TEST_CASE("test gguf") {
std::string file_path = get_temp_file("test_arr.gguf");
using dict = std::unordered_map<std::string, array>;
Expand Down
Loading