@@ -239,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
239239 test_path .write_text (modified_module .code , encoding = "utf-8" )
240240
241241
242- class OptimFunctionCollector (cst .CSTVisitor ):
243- METADATA_DEPENDENCIES = (cst .metadata .ParentNodeProvider ,)
244-
245- def __init__ (
246- self ,
247- preexisting_objects : set [tuple [str , tuple [FunctionParent , ...]]] | None = None ,
248- function_names : set [tuple [str | None , str ]] | None = None ,
249- ) -> None :
250- super ().__init__ ()
251- self .preexisting_objects = preexisting_objects if preexisting_objects is not None else set ()
252-
253- self .function_names = function_names # set of (class_name, function_name)
254- self .modified_functions : dict [
255- tuple [str | None , str ], cst .FunctionDef
256- ] = {} # keys are (class_name, function_name)
257- self .new_functions : list [cst .FunctionDef ] = []
258- self .new_class_functions : dict [str , list [cst .FunctionDef ]] = defaultdict (list )
259- self .new_classes : list [cst .ClassDef ] = []
260- self .current_class = None
261- self .modified_init_functions : dict [str , cst .FunctionDef ] = {}
262-
263- def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
264- if (self .current_class , node .name .value ) in self .function_names :
265- self .modified_functions [(self .current_class , node .name .value )] = node
266- elif self .current_class and node .name .value == "__init__" :
267- self .modified_init_functions [self .current_class ] = node
268- elif (
269- self .preexisting_objects
270- and (node .name .value , ()) not in self .preexisting_objects
271- and self .current_class is None
272- ):
273- self .new_functions .append (node )
274- return False
275-
276- def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
277- if self .current_class :
278- return False # If already in a class, do not recurse deeper
279- self .current_class = node .name .value
280-
281- parents = (FunctionParent (name = node .name .value , type = "ClassDef" ),)
282-
283- if (node .name .value , ()) not in self .preexisting_objects :
284- self .new_classes .append (node )
285-
286- for child_node in node .body .body :
287- if (
288- self .preexisting_objects
289- and isinstance (child_node , cst .FunctionDef )
290- and (child_node .name .value , parents ) not in self .preexisting_objects
291- ):
292- self .new_class_functions [node .name .value ].append (child_node )
293-
294- return True
295-
296- def leave_ClassDef (self , node : cst .ClassDef ) -> None :
297- if self .current_class :
298- self .current_class = None
299-
300-
301- class OptimFunctionReplacer (cst .CSTTransformer ):
302- def __init__ (
303- self ,
304- modified_functions : Optional [dict [tuple [str | None , str ], cst .FunctionDef ]] = None ,
305- new_classes : Optional [list [cst .ClassDef ]] = None ,
306- new_functions : Optional [list [cst .FunctionDef ]] = None ,
307- new_class_functions : Optional [dict [str , list [cst .FunctionDef ]]] = None ,
308- modified_init_functions : Optional [dict [str , cst .FunctionDef ]] = None ,
309- ) -> None :
310- super ().__init__ ()
311- self .modified_functions = modified_functions if modified_functions is not None else {}
312- self .new_functions = new_functions if new_functions is not None else []
313- self .new_classes = new_classes if new_classes is not None else []
314- self .new_class_functions = new_class_functions if new_class_functions is not None else defaultdict (list )
315- self .modified_init_functions : dict [str , cst .FunctionDef ] = (
316- modified_init_functions if modified_init_functions is not None else {}
317- )
318- self .current_class = None
319-
320- def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
321- return False
322-
323- def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
324- if (self .current_class , original_node .name .value ) in self .modified_functions :
325- node = self .modified_functions [(self .current_class , original_node .name .value )]
326- return updated_node .with_changes (body = node .body , decorators = node .decorators )
327- if original_node .name .value == "__init__" and self .current_class in self .modified_init_functions :
328- return self .modified_init_functions [self .current_class ]
329-
330- return updated_node
331-
332- def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
333- if self .current_class :
334- return False # If already in a class, do not recurse deeper
335- self .current_class = node .name .value
336- return True
337-
338- def leave_ClassDef (self , original_node : cst .ClassDef , updated_node : cst .ClassDef ) -> cst .ClassDef :
339- if self .current_class and self .current_class == original_node .name .value :
340- self .current_class = None
341- if original_node .name .value in self .new_class_functions :
342- return updated_node .with_changes (
343- body = updated_node .body .with_changes (
344- body = (list (updated_node .body .body ) + list (self .new_class_functions [original_node .name .value ]))
345- )
346- )
347- return updated_node
348-
349- def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
350- node = updated_node
351- max_function_index = None
352- max_class_index = None
353- for index , _node in enumerate (node .body ):
354- if isinstance (_node , cst .FunctionDef ):
355- max_function_index = index
356- if isinstance (_node , cst .ClassDef ):
357- max_class_index = index
358-
359- if self .new_classes :
360- existing_class_names = {_node .name .value for _node in node .body if isinstance (_node , cst .ClassDef )}
361-
362- unique_classes = [
363- new_class for new_class in self .new_classes if new_class .name .value not in existing_class_names
364- ]
365- if unique_classes :
366- new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports (node )
367- new_body = list (
368- chain (node .body [:new_classes_insertion_idx ], unique_classes , node .body [new_classes_insertion_idx :])
369- )
370- node = node .with_changes (body = new_body )
371-
372- if max_function_index is not None :
373- node = node .with_changes (
374- body = (* node .body [: max_function_index + 1 ], * self .new_functions , * node .body [max_function_index + 1 :])
375- )
376- elif max_class_index is not None :
377- node = node .with_changes (
378- body = (* node .body [: max_class_index + 1 ], * self .new_functions , * node .body [max_class_index + 1 :])
379- )
380- else :
381- node = node .with_changes (body = (* self .new_functions , * node .body ))
382- return node
383-
384-
385242def replace_functions_in_file (
386243 source_code : str ,
387244 original_function_names : list [str ],
0 commit comments