-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathprocess_images.py
More file actions
439 lines (375 loc) · 17.1 KB
/
process_images.py
File metadata and controls
439 lines (375 loc) · 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
#!/usr/bin/env python3
"""
Image processing script for OCR and entity extraction using OpenAI-compatible API.
Processes images from Downloads folder and extracts structured data.
"""
import argparse
import base64
import concurrent.futures
import json
import os
import re
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
@dataclass
class ProcessingResult:
    """Outcome of sending one image through the extraction pipeline.

    A successful result carries the parsed JSON payload in ``data``; a failed
    one carries the stringified exception in ``error``.
    """

    # Identifier of the image; the processor fills this with the path
    # relative to its downloads directory.
    filename: str
    # True when the model response was parsed into JSON, False otherwise.
    success: bool
    # Parsed JSON returned by the model (only set on success).
    data: Optional[Dict] = None
    # Text of the exception that aborted processing (only set on failure).
    error: Optional[str] = None
class ImageProcessor:
    """Process images using an OpenAI-compatible vision API.

    The processor scans ``downloads_dir`` recursively for image files, sends
    each one to the chat-completions endpoint together with a structured
    extraction prompt, and stores the parsed JSON under ``./results``. A JSON
    index file records which images were already handled so that a later run
    can skip them.
    """

    # Recognised image suffixes (lower-case) and the MIME type used for each
    # in the ``data:`` URL sent to the API.
    _MIME_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
        '.tiff': 'image/tiff',
        '.webp': 'image/webp',
    }
    _DEFAULT_MIME_TYPE = 'image/jpeg'
    _USER_PROMPT = "Extract all text and entities from this image. Return only valid JSON."

    def __init__(self, api_url: str, api_key: str, model: str = "gpt-4o", index_file: str = "processing_index.json", downloads_dir: Optional[str] = None):
        """Create the API client and load the index of processed files.

        Args:
            api_url: Base URL of the OpenAI-compatible API.
            api_key: API key passed to the client.
            model: Model name used for every request.
            index_file: Path of the JSON file tracking processed images.
            downloads_dir: Directory to scan; defaults to ``~/Downloads``.
        """
        self.client = OpenAI(api_key=api_key, base_url=api_url)
        self.model = model
        self.downloads_dir = Path(downloads_dir) if downloads_dir else Path.home() / "Downloads"
        self.index_file = index_file
        self.processed_files = self.load_index()

    def load_index(self) -> set:
        """Load the set of already processed file names from the index file.

        Returns an empty set when the index does not exist or cannot be read;
        an unreadable index only produces a warning (best effort).
        """
        if not os.path.exists(self.index_file):
            return set()
        try:
            with open(self.index_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return set(data.get('processed_files', []))
        except (OSError, ValueError, AttributeError, TypeError) as e:
            # ValueError covers malformed JSON and bad encodings;
            # AttributeError/TypeError cover an index with an unexpected shape.
            print(f"⚠️ Warning: Could not load index file: {e}")
            return set()

    def save_index(self, failed_files=None):
        """Write the current index of processed files to ``index_file``.

        Args:
            failed_files: Optional list of ``{'filename', 'error'}`` entries
                stored under ``failed_files`` for reference.
        """
        data = {
            'processed_files': sorted(self.processed_files),
            # Timestamp of this write (UTC, ISO 8601).
            'last_updated': datetime.now(timezone.utc).isoformat(),
        }
        if failed_files:
            data['failed_files'] = failed_files
        with open(self.index_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def mark_processed(self, filename: str):
        """Mark a file as processed and persist the index immediately."""
        self.processed_files.add(filename)
        self.save_index()

    def get_image_files(self) -> List[Path]:
        """Return all image files below the downloads directory, sorted.

        The suffix comparison is case-insensitive, so ``.JPG``, ``.Jpg`` and
        ``.jpg`` are all found, and each file is listed exactly once even on
        case-insensitive filesystems.
        """
        return sorted(
            path
            for path in self.downloads_dir.rglob('*')
            if path.suffix.lower() in self._MIME_TYPES and path.is_file()
        )

    def get_relative_path(self, file_path: Path) -> str:
        """Return ``file_path`` relative to the downloads directory.

        Falls back to the path as given when it is not located below the
        downloads directory.
        """
        try:
            return str(file_path.relative_to(self.downloads_dir))
        except ValueError:
            # If file is not relative to downloads_dir, use full path
            return str(file_path)

    def get_unprocessed_files(self) -> List[Path]:
        """Return only the image files that are not yet listed in the index."""
        return [
            f for f in self.get_image_files()
            if self.get_relative_path(f) not in self.processed_files
        ]

    def encode_image(self, image_path: Path) -> str:
        """Read the image and return its content as a base64 string."""
        with open(image_path, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')

    def get_system_prompt(self) -> str:
        """Get the system prompt for structured extraction"""
        return """You are an expert OCR and document analysis system.
Extract ALL text from the image in READING ORDER to create a digital twin of the document.
IMPORTANT: Transcribe text exactly as it appears on the page, from top to bottom, left to right, including:
- All printed text
- All handwritten text (inline where it appears)
- Stamps and annotations (inline where they appear)
- Signatures (note location)
Preserve the natural reading flow. Mix printed and handwritten text together in the order they appear.
Return ONLY valid JSON in this exact structure:
{
"document_metadata": {
"page_number": "string or null",
"document_number": "string or null",
"date": "string or null",
"document_type": "string or null",
"has_handwriting": true/false,
"has_stamps": true/false
},
"full_text": "Complete text transcription in reading order. Include ALL text - printed, handwritten, stamps, etc. - exactly as it appears from top to bottom.",
"text_blocks": [
{
"type": "printed|handwritten|stamp|signature|other",
"content": "text content",
"position": "top|middle|bottom|header|footer|margin"
}
],
"entities": {
"people": ["list of person names"],
"organizations": ["list of organizations"],
"locations": ["list of locations"],
"dates": ["list of dates found"],
"reference_numbers": ["list of any reference/ID numbers"]
},
"additional_notes": "Any observations about document quality, redactions, damage, etc."
}"""

    @classmethod
    def _guess_mime_type(cls, image_path) -> str:
        """Return the MIME type for an image path based on its suffix."""
        suffix = Path(image_path).suffix.lower()
        return cls._MIME_TYPES.get(suffix, cls._DEFAULT_MIME_TYPE)

    def _build_messages(self, base64_image: str, mime_type: str = "image/jpeg") -> list:
        """Build the system + user messages carrying the image as a data URL."""
        return [
            {
                "role": "system",
                "content": self.get_system_prompt()
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": self._USER_PROMPT
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}"
                        }
                    }
                ]
            }
        ]

    @staticmethod
    def _extract_json_text(content: str) -> str:
        """Return the part of a model reply that most likely holds the JSON."""
        content = content.strip()
        # 1. JSON between markdown code fences
        fenced = re.search(r'```(?:json)?\s*\n(.*?)\n```', content, re.DOTALL)
        if fenced:
            return fenced.group(1).strip()
        # 2. Everything from the first '{' to the last '}'
        braced = re.search(r'\{.*\}', content, re.DOTALL)
        if braced:
            return braced.group(0).strip()
        # 3. Strip stray fence markers manually
        if content.startswith('```json'):
            content = content[7:]
        elif content.startswith('```'):
            content = content[3:]
        if content.endswith('```'):
            content = content[:-3]
        return content.strip()

    @staticmethod
    def _decode_first_object(text: str):
        """Decode the first complete JSON value that starts at the first '{'.

        Uses the JSON decoder itself instead of counting braces, so braces
        inside string values do not confuse it; trailing text is ignored.

        Raises:
            ValueError: If there is no '{' or no valid JSON starts there
                (``json.JSONDecodeError`` is a ``ValueError``).
        """
        start = text.find('{')
        if start == -1:
            raise ValueError("No JSON object found")
        value, _ = json.JSONDecoder().raw_decode(text, start)
        return value

    def fix_json_with_llm(self, base64_image: str, broken_json: str, error_msg: str, mime_type: str = "image/jpeg") -> dict:
        """Ask the LLM to fix its own broken JSON.

        Replays the original request, the broken reply and the parser error,
        then parses the corrected reply.

        Args:
            base64_image: Base64-encoded image content.
            broken_json: The model's original, unparsable reply.
            error_msg: Parser error message shown to the model.
            mime_type: MIME type used in the image data URL.

        Raises:
            json.JSONDecodeError: If the corrected reply is still invalid.
        """
        messages = self._build_messages(base64_image, mime_type)
        messages.append({
            "role": "assistant",
            "content": broken_json
        })
        messages.append({
            "role": "user",
            "content": f"Your JSON response has an error: {error_msg}\n\nPlease fix the JSON and return ONLY the corrected valid JSON. Do not explain, just return the fixed JSON."
        })
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=4096,
            temperature=0.1
        )
        # An empty reply becomes "" so that json.loads reports the failure.
        content = response.choices[0].message.content or ""
        return json.loads(self._extract_json_text(content))

    def _parse_response(self, image_path: Path, base64_image: str, mime_type: str, original_content: str):
        """Turn a raw model reply into parsed JSON, with two fallbacks.

        Order of attempts: parse the extracted text, decode the first
        complete object in it, and finally ask the model to repair its reply.
        When every attempt fails the original reply is saved to the errors
        directory and the first parser error is raised.
        """
        content = self._extract_json_text(original_content)
        try:
            return json.loads(content)
        except json.JSONDecodeError as parse_error:
            try:
                # Salvage the first complete JSON object (drops trailing junk)
                return self._decode_first_object(content)
            except ValueError:
                # Last resort: Ask LLM to fix its JSON
                try:
                    return self.fix_json_with_llm(base64_image, original_content, str(parse_error), mime_type=mime_type)
                except Exception:
                    # Save ORIGINAL LLM response to errors directory (not our extracted version)
                    self.save_broken_json(self.get_relative_path(image_path), original_content)
                    # If even that fails, raise the original error
                    raise parse_error

    def process_image(self, image_path: Path) -> "ProcessingResult":
        """Process a single image through the API.

        Never raises: any exception is converted into a failed
        ``ProcessingResult`` whose ``error`` holds the exception text.
        """
        try:
            base64_image = self.encode_image(image_path)
            mime_type = self._guess_mime_type(image_path)
            response = self.client.chat.completions.create(
                model=self.model,
                messages=self._build_messages(base64_image, mime_type),
                max_tokens=4096,
                temperature=0.1
            )
            original_content = response.choices[0].message.content
            if original_content is None:
                raise ValueError("API returned an empty response")
            extracted_data = self._parse_response(image_path, base64_image, mime_type, original_content)
            return ProcessingResult(
                filename=self.get_relative_path(image_path),
                success=True,
                data=extracted_data
            )
        except Exception as e:
            return ProcessingResult(
                filename=self.get_relative_path(image_path),
                success=False,
                error=str(e)
            )

    def process_all(self, max_workers: int = 5, limit: Optional[int] = None, resume: bool = True) -> List["ProcessingResult"]:
        """Process all images using a pool of worker threads.

        Args:
            max_workers: Number of parallel worker threads.
            limit: If set, only the first ``limit`` files are processed.
            resume: When True, files already listed in the index are skipped.

        Returns:
            One ``ProcessingResult`` per processed file, in completion order.
        """
        if resume:
            image_files = self.get_unprocessed_files()
            total_files = len(self.get_image_files())
            already_processed = len(self.processed_files)
            print(f"Found {total_files} total image files")
            print(f"Already processed: {already_processed}")
            print(f"Remaining to process: {len(image_files)}")
        else:
            image_files = self.get_image_files()
            print(f"Found {len(image_files)} image files to process")
        if limit:
            image_files = image_files[:limit]
            print(f"Limited to {limit} files for this run")
        if not image_files:
            print("No files to process!")
            return []
        results = []
        failed_files = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(self.process_image, img): img for img in image_files}
            with tqdm(total=len(image_files), desc="Processing images") as pbar:
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    results.append(result)
                    # Save individual result to file
                    if result.success:
                        self.save_individual_result(result)
                        tqdm.write(f"✅ Processed: {result.filename}")
                    else:
                        # Track failed files
                        failed_files.append({
                            'filename': result.filename,
                            'error': result.error
                        })
                        tqdm.write(f"❌ Failed: {result.filename} - {result.error}")
                    # Mark as processed regardless of success/failure
                    self.mark_processed(result.filename)
                    pbar.update(1)
        # Save failed files to index for reference
        if failed_files:
            self.save_index(failed_files=failed_files)
            print(f"\n⚠️ {len(failed_files)} files failed - logged in {self.index_file}")
        return results

    @staticmethod
    def _mirrored_path(root: str, filename: str) -> Path:
        """Return ``root/filename`` with a ``.json`` suffix.

        An absolute ``filename`` has its anchor removed first; joining it
        unchanged would discard ``root`` and place the output outside it.
        """
        relative = Path(filename)
        if relative.is_absolute():
            relative = relative.relative_to(relative.anchor)
        return (Path(root) / relative).with_suffix('.json')

    def save_individual_result(self, result: "ProcessingResult"):
        """Save individual result to ./results/folder/imagename.json"""
        # Create output path mirroring the source structure
        result_path = self._mirrored_path("./results", result.filename)
        result_path.parent.mkdir(parents=True, exist_ok=True)
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(result.data, f, indent=2, ensure_ascii=False)

    def save_broken_json(self, filename: str, broken_content: str):
        """Save an unparsable model reply, as-is, under the ./errors directory."""
        error_path = self._mirrored_path("./errors", filename)
        error_path.parent.mkdir(parents=True, exist_ok=True)
        with open(error_path, 'w', encoding='utf-8') as f:
            f.write(broken_content)

    def save_results(self, results: List["ProcessingResult"], output_file: str = "processed_results.json"):
        """Save a summary of all results to a JSON file and print the totals."""
        output_data = {
            "total_processed": len(results),
            "successful": sum(1 for r in results if r.success),
            "failed": sum(1 for r in results if not r.success),
            "results": [asdict(r) for r in results]
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Summary saved to {output_file}")
        print(f" Individual results saved to ./results/")
        print(f" Successful: {output_data['successful']}")
        print(f" Failed: {output_data['failed']}")
def main():
    """Command-line entry point: parse options, run the processor, save the summary."""
    # Make values from a .env file visible through os.getenv below.
    load_dotenv()

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--api-url", {"help": "OpenAI-compatible API base URL (default: from .env or OPENAI_API_URL)"}),
        ("--api-key", {"help": "API key (default: from .env or OPENAI_API_KEY)"}),
        ("--model", {"help": "Model name (default: from .env, OPENAI_MODEL, or meta-llama/Llama-4-Maverick-17B-128E-Instruct)"}),
        ("--workers", {"type": int, "default": 5, "help": "Number of parallel workers (default: 5)"}),
        ("--limit", {"type": int, "help": "Limit number of images to process (for testing)"}),
        ("--output", {"default": "processed_results.json", "help": "Output JSON file"}),
        ("--index", {"default": "processing_index.json", "help": "Index file to track processed files"}),
        ("--downloads-dir", {"default": "./downloads", "help": "Directory containing images (default: ./downloads)"}),
        ("--no-resume", {"action": "store_true", "help": "Process all files, ignoring index"}),
    ]
    cli = argparse.ArgumentParser(description="Process images with OCR and entity extraction")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # A command-line value wins; otherwise the environment variable is used,
    # and the literal fallback applies when neither is set.
    processor = ImageProcessor(
        api_url=opts.api_url or os.getenv("OPENAI_API_URL", "http://..."),
        api_key=opts.api_key or os.getenv("OPENAI_API_KEY", "abcd1234"),
        model=opts.model or os.getenv("OPENAI_MODEL", "meta-llama/Llama-4-Maverick-17B-128E-Instruct"),
        index_file=opts.index,
        downloads_dir=opts.downloads_dir,
    )
    outcome = processor.process_all(
        max_workers=opts.workers,
        limit=opts.limit,
        resume=not opts.no_resume,
    )
    processor.save_results(outcome, opts.output)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()