-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfix_imports.py
More file actions
162 lines (135 loc) · 6.13 KB
/
fix_imports.py
File metadata and controls
162 lines (135 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""
Enhanced script to fix internal imports in the autogenerated LCM Python files.
This script extends fix_imports.py to specifically handle cases where classes
within the same package need to be imported explicitly.
"""
import os
import re
from pathlib import Path
# Names of the generated message packages whose import statements get
# rewritten. The order is significant to callers that iterate this list.
MSG_PACKAGES = [
    "sensor_msgs",
    "geometry_msgs",
    "std_msgs",
    "actionlib_msgs",
    "builtin_interfaces",
    "diagnostic_msgs",
    "foxglove_msgs",
    "nav_msgs",
    "shape_msgs",
    "stereo_msgs",
    "tf2_msgs",
    "trajectory_msgs",
    "visualization_msgs",
]
# A column-0 ``import x`` / ``from x import ...`` statement.
_TOP_LEVEL_IMPORT_RE = re.compile(
    r'^(?:import[ \t]+\S|from[ \t]+\S+[ \t]+import\b)', re.MULTILINE)
# The first column-0 ``class``/``def`` marks the end of the module header.
_TOP_LEVEL_DEF_RE = re.compile(r'^(?:class|def)\s', re.MULTILINE)
# Blank lines and ``#`` comment lines (shebang included) at the top of a file.
_LEADING_COMMENTS_RE = re.compile(r'(?:[ \t]*(?:#[^\n]*)?\n)*')


def _package_for_path(file_path):
    """Return the entry of MSG_PACKAGES that names a directory of file_path.

    The path component closest to the file wins, so an unrelated ancestor
    directory cannot shadow the real package. Returns None when no component
    is a known package.
    """
    for part in reversed(Path(file_path).parts):
        if part in MSG_PACKAGES:
            return part
    return None


def _next_line_start(text, pos):
    """Return the offset of the line after the one containing ``pos``."""
    newline = text.find('\n', pos)
    return len(text) if newline == -1 else newline + 1


def _import_insert_offset(content):
    """Return the offset in ``content`` where a top-level import can be added.

    Preference order: right after the last column-0 import that precedes the
    first top-level ``class``/``def``; otherwise right after a leading module
    docstring; otherwise after any leading comment lines.
    """
    first_def = _TOP_LEVEL_DEF_RE.search(content)
    header_end = first_def.start() if first_def else len(content)
    imports = list(_TOP_LEVEL_IMPORT_RE.finditer(content, 0, header_end))
    if imports:
        start = imports[-1].start()
        end = _next_line_start(content, start)
        # Step over parenthesised or backslash-continued import statements
        # so the new line is never placed in the middle of one.
        while end < len(content):
            stmt = content[start:end]
            unbalanced = stmt.count('(') > stmt.count(')')
            if not unbalanced and not stmt.rstrip('\n').endswith('\\'):
                break
            end = _next_line_start(content, end)
        return end

    offset = _LEADING_COMMENTS_RE.match(content).end()
    for quote in ('"""', "'''"):
        if content.startswith(quote, offset):
            closing = content.find(quote, offset + len(quote))
            if closing != -1:
                return _next_line_start(content, closing)
    return offset


def fix_file_imports(file_path):
    """Fix import statements in a generated message file.

    For a file that lives in one of the MSG_PACKAGES directories this:

    * rewrites ``import <own package>`` to ``from . import *`` and strips the
      ``<own package>.`` qualifier from names,
    * rewrites ``import <other package>`` to
      ``from lcm_msgs import <other package>``,
    * adds ``from .Name import Name`` for every referenced class ``Name`` that
      has a ``Name.py`` next to the file, dropping a stale
      ``from . import Name`` line if present.

    Args:
        file_path: Path (str or path-like) of the Python file to rewrite.

    Returns:
        True if the file content changed and was written back, False if the
        file is not inside a known package or needed no change.
    """
    pkg_name = _package_for_path(file_path)
    if not pkg_name:
        return False

    path = Path(file_path)
    with open(path, 'r', encoding='utf-8') as f:
        original = f.read()
    content = original

    for pkg in MSG_PACKAGES:
        # Anchored per line so neighbouring blank lines are preserved and an
        # import on the first or an unterminated last line is still matched.
        import_re = re.compile(
            fr'^import[ \t]+{re.escape(pkg)}[ \t]*$', re.MULTILINE)
        if pkg == pkg_name:
            content = import_re.sub('from . import *', content)
            # Replace pkg_name.ClassName with ClassName. The lookbehind keeps
            # longer dotted names and identifier suffixes untouched.
            content = re.sub(
                fr'(?<![\w.]){re.escape(pkg)}\.(\w+)', r'\1', content)
        else:
            # Cross-package import (e.g. geometry_msgs importing std_msgs).
            content = import_re.sub(f'from lcm_msgs import {pkg}', content)

    # Classes instantiated as attribute defaults or used through the
    # generated static helpers.
    referenced = set(
        re.findall(r'self\.[\w\.]+ = ([A-Z][a-zA-Z0-9]+)\(\)', content))
    referenced.update(re.findall(
        r'([A-Z][a-zA-Z0-9]+)\._(?:decode_one|get_hash_recursive|get_packed_fingerprint)',
        content))
    # A file never needs to import its own class.
    referenced.discard(path.stem)

    # Sorted so the emitted import order is the same on every run.
    for class_name in sorted(referenced):
        if not (path.parent / f"{class_name}.py").is_file():
            continue
        name = re.escape(class_name)
        # Drop a whole ``from . import Name`` line only; the trailing anchor
        # prevents clipping a longer name such as ``NameStamped``.
        content = re.sub(
            fr'^from \. import {name}[ \t]*(?:\n|\Z)', '', content,
            flags=re.MULTILINE)
        if re.search(fr'^from \.{name} import {name}[ \t]*$', content,
                     re.MULTILINE):
            continue
        offset = _import_insert_offset(content)
        # Guard against an unterminated final line before the insert point.
        separator = '\n' if offset and content[offset - 1] != '\n' else ''
        content = (content[:offset] + separator
                   + f"from .{class_name} import {class_name}\n"
                   + content[offset:])
        print(f"Added import for {class_name} in {file_path}")

    if content == original:
        return False
    print(f"Fixed imports in {file_path}")
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    return True
def main():
    """Run the import fixer over every Python file below ``lcm_msgs``.

    The ``lcm_msgs`` directory is expected next to this script. If it is
    missing (or is not a directory) an error is printed and nothing is
    processed. Otherwise each ``*.py`` file found recursively is passed to
    ``fix_file_imports`` and the number of files it reports as changed is
    printed.
    """
    lcm_msgs_dir = Path(os.path.dirname(os.path.abspath(__file__))) / 'lcm_msgs'

    # Bail out early when there is no package tree to work on.
    if not (lcm_msgs_dir.exists() and lcm_msgs_dir.is_dir()):
        print(f"Error: lcm_msgs directory not found at {lcm_msgs_dir}")
        return

    # Collect the full file list up front, then count the files that changed.
    candidates = list(lcm_msgs_dir.glob('**/*.py'))
    fixed_count = sum(1 for candidate in candidates if fix_file_imports(candidate))
    print(f"Fixed imports in {fixed_count} files.")


if __name__ == '__main__':
    main()