-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_processing.py
More file actions
465 lines (365 loc) · 16.1 KB
/
batch_processing.py
File metadata and controls
465 lines (365 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
#!/usr/bin/env python3
"""
Scintirete SDK 批处理示例
演示如何使用 Scintirete Python SDK 进行大规模数据的批处理操作。
"""
import random
import time
import asyncio
from typing import List, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from scintirete_sdk import (
ScintireteClient,
ScintireteAsyncClient,
DistanceMetric,
HnswConfig,
Vector,
TextWithMetadata,
ScintireteError,
)
def generate_batch_vectors(start_id: int, batch_size: int, dimension: int = 128) -> List[Vector]:
    """Build one batch of random vectors with per-item metadata.

    Args:
        start_id: Global id of the first item in this batch.
        batch_size: Number of vectors to generate.
        dimension: Length of each vector's element list.

    Returns:
        A list of ``batch_size`` Vector objects whose metadata records the
        batch id, item id, a cyclic category label, a timestamp, and the
        id range covered by the batch.
    """
    last_id = start_id + batch_size - 1

    def _build(offset: int) -> Vector:
        item_id = start_id + offset
        return Vector(
            elements=[random.random() for _ in range(dimension)],
            metadata={
                "batch_id": start_id // batch_size,
                "item_id": item_id,
                "category": f"category_{item_id % 10}",
                "timestamp": time.time(),
                "batch_info": f"batch_{start_id}-{last_id}",
            },
        )

    return [_build(i) for i in range(batch_size)]
def batch_data_generator(total_size: int, batch_size: int, dimension: int = 128) -> Iterator[List[Vector]]:
    """Yield successive batches of generated vectors.

    Args:
        total_size: Total number of vectors to produce across all batches.
        batch_size: Maximum size of each yielded batch; the final batch may
            be smaller when ``total_size`` is not a multiple of it.
        dimension: Dimensionality passed through to the vector generator.

    Yields:
        Lists of Vector objects produced by ``generate_batch_vectors``.
    """
    offset = 0
    while offset < total_size:
        remaining = total_size - offset
        yield generate_batch_vectors(offset, min(batch_size, remaining), dimension)
        offset += batch_size
def demonstrate_sync_batch_insert(client: ScintireteClient, db_name: str, collection_name: str):
    """Demonstrate synchronous batched insertion.

    Inserts 1000 vectors in batches of 100 through the blocking client,
    printing per-batch latency plus aggregate throughput.

    Returns:
        All server-assigned vector ids on success, or an empty list if a
        batch fails with a ScintireteError.
    """
    print("\n📥 同步批量插入演示")
    print("=" * 50)

    total_vectors = 1000
    batch_size = 100
    total_batches = (total_vectors + batch_size - 1) // batch_size

    print(f"准备插入 {total_vectors} 个向量,分 {total_batches} 批处理")
    print(f"每批大小: {batch_size}")

    t_start = time.time()
    inserted_total = 0
    collected_ids = []

    try:
        for idx, batch in enumerate(batch_data_generator(total_vectors, batch_size)):
            t_batch = time.time()
            # Insert the current batch via the synchronous API.
            ids, count = client.insert_vectors(
                db_name=db_name,
                collection_name=collection_name,
                vectors=batch,
            )
            elapsed = time.time() - t_batch
            inserted_total += count
            collected_ids.extend(ids)
            print(f" 批次 {idx + 1}/{total_batches}: 插入 {count} 个向量,耗时 {elapsed:.3f}s")

        total_elapsed = time.time() - t_start
        print(f"\n✅ 同步批量插入完成:")
        print(f" 总向量数: {inserted_total}")
        print(f" 总耗时: {total_elapsed:.3f}s")
        print(f" 平均速度: {inserted_total / total_elapsed:.1f} vectors/s")
        return collected_ids
    except ScintireteError as e:
        print(f"❌ 同步批量插入失败: {e}")
        return []
async def demonstrate_async_batch_insert(client: ScintireteAsyncClient, db_name: str, collection_name: str):
    """Demonstrate asynchronous batched insertion with bounded concurrency.

    Inserts 1000 vectors in batches of 100, running at most 5 batch
    inserts concurrently via a semaphore, and prints aggregate throughput.

    Returns:
        All server-assigned vector ids (failed batches contribute none),
        or an empty list if the overall run fails.
    """
    print("\n🚀 异步批量插入演示")
    print("=" * 50)

    total_vectors = 1000
    batch_size = 100
    max_concurrent = 5  # upper bound on in-flight batch inserts

    print(f"准备异步插入 {total_vectors} 个向量")
    print(f"每批大小: {batch_size}, 最大并发: {max_concurrent}")

    t_start = time.time()
    gate = asyncio.Semaphore(max_concurrent)

    async def run_batch(index: int, batch: List[Vector]):
        """Insert one batch under the concurrency gate; never raises."""
        async with gate:
            t_batch = time.time()
            try:
                ids, count = await client.insert_vectors(
                    db_name=db_name,
                    collection_name=collection_name,
                    vectors=batch,
                )
            except Exception as e:
                print(f" 批次 {index} 失败: {e}")
                return [], 0
            elapsed = time.time() - t_batch
            print(f" 批次 {index}: 插入 {count} 个向量,耗时 {elapsed:.3f}s")
            return ids, count

    try:
        # Fan out one gated task per batch and wait for them all.
        jobs = [
            run_batch(i, batch)
            for i, batch in enumerate(batch_data_generator(total_vectors, batch_size))
        ]
        outcomes = await asyncio.gather(*jobs)

        # Aggregate the per-batch results.
        inserted_total = 0
        collected_ids = []
        for ids, count in outcomes:
            inserted_total += count
            collected_ids.extend(ids)

        total_elapsed = time.time() - t_start
        print(f"\n✅ 异步批量插入完成:")
        print(f" 总向量数: {inserted_total}")
        print(f" 总耗时: {total_elapsed:.3f}s")
        print(f" 平均速度: {inserted_total / total_elapsed:.1f} vectors/s")
        return collected_ids
    except Exception as e:
        print(f"❌ 异步批量插入失败: {e}")
        return []
def demonstrate_concurrent_search(client: ScintireteClient, db_name: str, collection_name: str, num_queries: int = 50):
    """Demonstrate concurrent searches through a thread pool.

    Fires ``num_queries`` random 128-d top-10 searches across 10 worker
    threads and prints per-query latency plus aggregate throughput.

    Args:
        num_queries: Number of search requests to issue.
    """
    print("\n🔍 并发搜索演示(线程池)")
    print("=" * 50)
    print(f"准备执行 {num_queries} 个并发搜索查询")

    # Pre-generate one random 128-d query vector per request.
    queries = [[random.random() for _ in range(128)] for _ in range(num_queries)]

    def run_query(qid: int, vec: List[float]):
        """Run one search; returns (id, hit_count, seconds, error_or_None)."""
        t0 = time.time()
        try:
            hits = client.search(
                db_name=db_name,
                collection_name=collection_name,
                query_vector=vec,
                top_k=10,
            )
        except Exception as e:
            return qid, 0, time.time() - t0, str(e)
        return qid, len(hits), time.time() - t0, None

    t_all = time.time()
    done = 0
    ok = 0
    hits_total = 0

    with ThreadPoolExecutor(max_workers=10) as pool:
        futures = [pool.submit(run_query, i, q) for i, q in enumerate(queries)]
        # Report each query as soon as it finishes.
        for fut in as_completed(futures):
            qid, count, seconds, err = fut.result()
            done += 1
            if err is None:
                ok += 1
                hits_total += count
                print(f" 查询 {qid}: {count} 个结果,耗时 {seconds:.3f}s")
            else:
                print(f" 查询 {qid}: 失败 - {err}")

    total_elapsed = time.time() - t_all
    print(f"\n✅ 并发搜索完成:")
    print(f" 总查询数: {done}")
    print(f" 成功查询: {ok}")
    print(f" 总结果数: {hits_total}")
    print(f" 总耗时: {total_elapsed:.3f}s")
    print(f" 平均查询时间: {total_elapsed / num_queries:.3f}s")
    print(f" 查询吞吐量: {num_queries / total_elapsed:.1f} queries/s")
def demonstrate_batch_delete(client: ScintireteClient, db_name: str, collection_name: str, vector_ids: List[int]):
    """Demonstrate deletion in fixed-size batches.

    Deletes ``vector_ids`` in chunks of 100, printing per-batch latency
    and aggregate totals. A failed batch is reported and skipped; the
    remaining batches still run.

    Args:
        vector_ids: Ids to delete; returns immediately when empty.
    """
    print("\n🗑️ 批量删除演示")
    print("=" * 50)

    if not vector_ids:
        print("没有可删除的向量ID")
        return

    batch_size = 100
    deleted_total = 0
    total_batches = (len(vector_ids) + batch_size - 1) // batch_size

    print(f"准备删除 {len(vector_ids)} 个向量,每批 {batch_size} 个")
    t_start = time.time()

    # Walk the id list one chunk at a time.
    for batch_num, offset in enumerate(range(0, len(vector_ids), batch_size), start=1):
        chunk = vector_ids[offset:offset + batch_size]
        t_batch = time.time()
        try:
            removed = client.delete_vectors(
                db_name=db_name,
                collection_name=collection_name,
                ids=chunk,
            )
        except ScintireteError as e:
            print(f" 批次删除失败: {e}")
            continue
        elapsed = time.time() - t_batch
        deleted_total += removed
        print(f" 批次 {batch_num}/{total_batches}: 删除 {removed} 个向量,耗时 {elapsed:.3f}s")

    total_elapsed = time.time() - t_start
    print(f"\n✅ 批量删除完成:")
    print(f" 删除向量数: {deleted_total}")
    print(f" 总耗时: {total_elapsed:.3f}s")
def demonstrate_performance_monitoring(client: ScintireteClient, db_name: str, collection_name: str):
    """Demonstrate collection statistics and top_k latency sampling.

    Prints the collection's size/memory/metric info, then times one
    search per top_k in {1, 5, 10, 50, 100} using a single random
    query vector of the collection's dimension.
    """
    print("\n📊 性能监控演示")
    print("=" * 50)

    # Report the collection's current statistics.
    info = client.get_collection_info(db_name, collection_name)
    print(f"集合统计信息:")
    print(f" 向量总数: {info.vector_count:,}")
    print(f" 已删除数: {info.deleted_count:,}")
    print(f" 有效向量: {info.vector_count - info.deleted_count:,}")
    print(f" 内存使用: {info.memory_bytes / 1024 / 1024:.2f} MB")
    print(f" 向量维度: {info.dimension}")
    print(f" 距离度量: {info.metric_type}")

    # Latency sweep across increasing top_k values with one fixed probe.
    print(f"\n不同 top_k 值的搜索性能测试:")
    probe = [random.random() for _ in range(info.dimension)]
    for k in (1, 5, 10, 50, 100):
        t0 = time.time()
        try:
            hits = client.search(
                db_name=db_name,
                collection_name=collection_name,
                query_vector=probe,
                top_k=k,
            )
        except ScintireteError as e:
            print(f" top_k={k:3d}: 搜索失败 - {e}")
            continue
        elapsed = time.time() - t0
        print(f" top_k={k:3d}: {len(hits):3d} 个结果,耗时 {elapsed:.4f}s")
async def main():
    """Drive the complete batch-processing demo.

    Runs the synchronous-client demos (batch insert, concurrent search,
    performance monitoring, batch delete), then the asynchronous-client
    demos (batch insert, bounded-concurrency search), and finally drops
    the test database.

    Raises:
        Exception: any unexpected failure is reported and re-raised.
    """
    # Connection configuration.
    SERVER_ADDRESS = "localhost:50051"
    PASSWORD = None
    # Test data configuration.
    DB_NAME = "batch_processing_demo"
    COLLECTION_NAME = "batch_vectors"

    try:
        print("🔗 批处理演示开始...")

        # 1. Synchronous-client batch-processing demos.
        print("\n" + "="*60)
        print("同步客户端批处理演示")
        print("="*60)
        with ScintireteClient(SERVER_ADDRESS, password=PASSWORD) as sync_client:
            # Prepare a clean test environment. The database may not exist
            # yet, so the drop is best-effort — but catch Exception, not a
            # bare `except:` which would also swallow KeyboardInterrupt.
            try:
                sync_client.drop_database(DB_NAME)
            except Exception:
                pass
            sync_client.create_database(DB_NAME)
            sync_client.create_collection(
                db_name=DB_NAME,
                collection_name=COLLECTION_NAME,
                metric_type=DistanceMetric.COSINE,
                hnsw_config=HnswConfig(m=16, ef_construction=200)
            )

            # Synchronous batch insert.
            sync_inserted_ids = demonstrate_sync_batch_insert(sync_client, DB_NAME, COLLECTION_NAME)

            # Give the server time to build the index.
            print("\n⏳ 等待索引构建...")
            time.sleep(3)

            # Concurrent search and performance monitoring.
            demonstrate_concurrent_search(sync_client, DB_NAME, COLLECTION_NAME)
            demonstrate_performance_monitoring(sync_client, DB_NAME, COLLECTION_NAME)

            # Batch delete half of the inserted vectors.
            if sync_inserted_ids:
                ids_to_delete = sync_inserted_ids[:len(sync_inserted_ids)//2]
                demonstrate_batch_delete(sync_client, DB_NAME, COLLECTION_NAME, ids_to_delete)

        # 2. Asynchronous-client batch-processing demos.
        print("\n" + "="*60)
        print("异步客户端批处理演示")
        print("="*60)
        async with ScintireteAsyncClient(SERVER_ADDRESS, password=PASSWORD) as async_client:
            # Recreate the database from scratch (drop is best-effort).
            try:
                await async_client.drop_database(DB_NAME)
            except Exception:
                pass
            await async_client.create_database(DB_NAME)
            await async_client.create_collection(
                db_name=DB_NAME,
                collection_name=COLLECTION_NAME,
                metric_type=DistanceMetric.COSINE,
                hnsw_config=HnswConfig(m=16, ef_construction=200)
            )

            # Asynchronous batch insert.
            async_inserted_ids = await demonstrate_async_batch_insert(async_client, DB_NAME, COLLECTION_NAME)

            # Give the server time to build the index.
            print("\n⏳ 等待索引构建...")
            await asyncio.sleep(3)

            # Asynchronous concurrent search demo.
            print("\n🚀 异步并发搜索演示")
            print("=" * 50)
            num_queries = 100
            query_vectors = [
                [random.random() for _ in range(128)]
                for _ in range(num_queries)
            ]

            async def async_search_single(query_id: int, query_vector: List[float]):
                """Run one search; returns (id, hit_count, seconds, error_or_None)."""
                start_time = time.time()
                try:
                    results = await async_client.search(
                        db_name=DB_NAME,
                        collection_name=COLLECTION_NAME,
                        query_vector=query_vector,
                        top_k=10
                    )
                    return query_id, len(results), time.time() - start_time, None
                except Exception as e:
                    return query_id, 0, time.time() - start_time, str(e)

            # Cap in-flight searches at 20 with a semaphore.
            semaphore = asyncio.Semaphore(20)

            async def controlled_search(query_id: int, query_vector: List[float]):
                async with semaphore:
                    return await async_search_single(query_id, query_vector)

            start_time = time.time()
            results = await asyncio.gather(
                *(controlled_search(i, vec) for i, vec in enumerate(query_vectors))
            )
            total_duration = time.time() - start_time

            successful = sum(1 for _, _, _, error in results if error is None)
            total_results = sum(count for _, count, _, error in results if error is None)

            print(f"✅ 异步并发搜索完成:")
            print(f" 总查询数: {num_queries}")
            print(f" 成功查询: {successful}")
            print(f" 总结果数: {total_results}")
            print(f" 总耗时: {total_duration:.3f}s")
            print(f" 查询吞吐量: {num_queries / total_duration:.1f} queries/s")

        # Cleanup: drop the test database with a short-lived sync client.
        # Use the context manager instead of manual try/finally close,
        # matching how the client is used earlier in this function.
        with ScintireteClient(SERVER_ADDRESS, password=PASSWORD) as cleanup_client:
            try:
                cleanup_client.drop_database(DB_NAME)
                print(f"\n🧹 清理测试数据库: {DB_NAME}")
            except Exception:
                pass

        print("\n🎉 批处理演示完成!")
    except Exception as e:
        print(f"❌ 批处理演示失败: {e}")
        raise
if __name__ == "__main__":
    # Run the asynchronous entry point.
    asyncio.run(main())