-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfvecs_remove_zeros.py
More file actions
106 lines (76 loc) · 2.95 KB
/
fvecs_remove_zeros.py
File metadata and controls
106 lines (76 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
import argparse
import os
import struct
import numpy as np
def read_fvecs(fname):
    """Load an ``.fvecs`` file into a 2-D float32 array.

    The file is parsed as a sequence of fixed-width records, each made of a
    little-endian int32 dimension header followed by that many little-endian
    float32 components. All records must share the same dimension.

    Args:
        fname: Path to the input file; a leading ``~`` is expanded.

    Returns:
        A C-contiguous ``float32`` array of shape ``(n_vectors, dim)``.
        An empty file yields an array of shape ``(0, 0)``.

    Raises:
        ValueError: If the byte length is not a multiple of 4, the leading
            dimension is not positive, the element count is not a multiple
            of ``dim + 1``, or the per-record dimensions disagree.
    """
    fname = os.path.expanduser(fname)

    # np.fromfile only returns whole elements, so a truncated tail of 1-3
    # bytes would otherwise go unnoticed; check the byte length explicitly.
    byte_count = os.path.getsize(fname)
    if byte_count % 4 != 0:
        raise ValueError(
            f"File size is not a multiple of 4 bytes: "
            f"{fname}, byte_count={byte_count}"
        )

    # Read with an explicit little-endian integer dtype so header decoding
    # does not depend on the host byte order; floats are recovered via a view.
    raw = np.fromfile(fname, dtype="<i4")
    if raw.size == 0:
        return np.empty((0, 0), dtype=np.float32)

    # Signed interpretation lets a negative (corrupt) header be reported as
    # an invalid dimension.
    dim = int(raw[0])
    if dim <= 0:
        raise ValueError(f"Invalid dimension {dim} in {fname}")

    row_width = dim + 1
    if raw.size % row_width != 0:
        raise ValueError(
            f"File size is not consistent with fvecs format: "
            f"{fname}, dim={dim}, float_count={raw.size}"
        )

    records = raw.reshape(-1, row_width)
    if not np.all(records[:, 0] == dim):
        raise ValueError(f"Inconsistent vector dimensions in {fname}")

    # Reinterpret the payload columns as float32 (same item size, so the
    # view is valid on the strided slice) and return a compact copy.
    payload = records[:, 1:].view("<f4")
    return np.ascontiguousarray(payload, dtype=np.float32)
def write_fvecs(fname, arr):
    """Write a 2-D array to ``fname`` as fixed-width ``.fvecs`` records.

    Each row is written as a little-endian int32 holding the row length
    followed by the row's values as little-endian float32.

    Args:
        fname: Destination path; a leading ``~`` is expanded. The file is
            opened in binary write mode, replacing any existing content.
        arr: Array-like convertible to a 2-D ``float32`` array of shape
            ``(n_vectors, dim)``.

    Raises:
        ValueError: If ``arr`` is not two-dimensional.
    """
    arr = np.asarray(arr, dtype=np.float32)
    if arr.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {arr.shape}")
    n, d = arr.shape
    fname = os.path.expanduser(fname)

    # Assemble records in an integer buffer so the dimension header is stored
    # as a real int32. Routing it through a float would turn it into a
    # subnormal value, which can be flushed to zero under some FP modes.
    records = np.empty((n, d + 1), dtype="<i4")
    records[:, 0] = d
    # Same item size, so the float view aliases the payload columns and the
    # assignment writes the float32 bit patterns in place.
    records[:, 1:].view("<f4")[...] = arr

    with open(fname, "wb") as f:
        records.tofile(f)
def count_zero_vectors(arr, tol=0.0):
    """Count the rows of ``arr`` whose L2 norm is at or below ``tol``.

    Args:
        arr: 2-D array-like of vectors, one per row.
        tol: Non-negative threshold; rows with ``norm <= tol`` are counted.

    Returns:
        The number of matching rows as a Python ``int``. Rows whose norm is
        NaN never satisfy ``norm <= tol`` and are therefore not counted.
    """
    arr = np.asarray(arr)
    # Accumulate the squared components in float64. Squaring in float32 makes
    # small non-zero values (e.g. 1e-30) underflow to 0, which would report a
    # non-zero vector as exactly zero. einsum avoids a full float64 copy.
    sq_norms = np.einsum("ij,ij->i", arr, arr, dtype=np.float64)
    return int(np.count_nonzero(np.sqrt(sq_norms) <= tol))
def remove_zero_vectors(arr, tol=0.0):
    """Return the rows of ``arr`` whose L2 norm is not at or below ``tol``.

    Args:
        arr: 2-D array-like of vectors, one per row.
        tol: Non-negative threshold; rows with ``norm <= tol`` are dropped.

    Returns:
        A C-contiguous ``float32`` array containing the surviving rows in
        their original order. Rows whose norm is NaN do not satisfy
        ``norm <= tol`` and are kept, so the number of dropped rows equals
        the number of rows with ``norm <= tol``.
    """
    arr = np.asarray(arr)
    # Accumulate the squared components in float64. Squaring in float32 makes
    # small non-zero values (e.g. 1e-30) underflow to 0, which would discard
    # a non-zero vector as if it were exactly zero.
    sq_norms = np.einsum("ij,ij->i", arr, arr, dtype=np.float64)
    drop_mask = np.sqrt(sq_norms) <= tol
    return np.ascontiguousarray(arr[~drop_mask], dtype=np.float32)
def main(argv=None):
    """Command-line entry point: filter near-zero vectors out of an fvecs file.

    Parses ``--input``, ``--output`` and ``--tolerance``, loads the input with
    ``read_fvecs``, reports how many vectors have an L2 norm at or below the
    tolerance, removes them with ``remove_zero_vectors`` and writes the rest
    with ``write_fvecs``.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.

    Raises:
        ValueError: If the tolerance is negative or NaN, if the input holds
            no vectors, or if every vector was removed.
    """
    parser = argparse.ArgumentParser(
        description="Remove vectors whose L2 norm is at or below a tolerance from an fvecs file."
    )
    parser.add_argument("--input", required=True, help="Input fvecs file")
    parser.add_argument("--output", required=True, help="Output fvecs file with near-zero vectors removed")
    parser.add_argument(
        "--tolerance",
        type=float,
        default=0.0,
        help="Remove vectors with L2 norm <= tolerance (default: 0.0)",
    )
    args = parser.parse_args(argv)

    # Written as a negated ">=" so that NaN is rejected too: every comparison
    # against NaN is False, which would otherwise silently disable filtering.
    if not args.tolerance >= 0:
        raise ValueError("--tolerance must be non-negative")

    vectors = read_fvecs(args.input)
    # Report an empty input as such instead of claiming all vectors were zero.
    if vectors.shape[0] == 0:
        raise ValueError(f"No vectors found in input file: {args.input}")

    zero_count = count_zero_vectors(vectors, tol=args.tolerance)
    print(f"Zero tolerance: {args.tolerance}")
    print(f"Zero-like vectors: {zero_count} / {vectors.shape[0]}")

    cleaned = remove_zero_vectors(vectors, tol=args.tolerance)
    # Refuse to produce an output file that contains no vectors.
    if cleaned.shape[0] == 0:
        raise ValueError("All vectors were zero after removal.")

    removed = vectors.shape[0] - cleaned.shape[0]
    print(f"Removed zero vectors: {removed}")
    print(f"Remaining vectors: {cleaned.shape[0]}")
    print(f"Dimension: {cleaned.shape[1]}")

    write_fvecs(args.output, cleaned)
    print(f"Wrote cleaned file to: {args.output}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()