diff --git a/comfyui_to_python.py b/comfyui_to_python.py index 0006ca8..0ede4a6 100644 --- a/comfyui_to_python.py +++ b/comfyui_to_python.py @@ -6,16 +6,18 @@ import random import sys import re +import keyword from typing import Dict, List, Any, Callable, Tuple, TextIO import black from comfyui_to_python_utils import import_custom_nodes, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths,\ - get_value_at_index, parse_arg, save_image_wrapper + get_value_at_index, parse_arg, save_image_wrapper, ensure_string PACKAGED_FUNCTIONS = [ get_value_at_index, + ensure_string, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths, @@ -204,6 +206,9 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str: custom_nodes = False + # Map node id -> return types for downstream type-aware formatting + self._idx_to_return_types: Dict[str, Tuple] = {} + # Loop over each dictionary in the load order list for idx, data, is_special_function in load_order: # Generate class definition and inputs from the data @@ -246,16 +251,62 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str: if class_def_params is not None: if 'unique_id' in class_def_params: inputs['unique_id'] = random.randint(1, 2**64) - if 'prompt' in class_def_params: + # only inject PROMPT_DATA when node declares hidden prompt + try: + hidden_inputs = input_types.get('hidden', {}) if isinstance(input_types, dict) else {} + except Exception: + hidden_inputs = {} + if isinstance(hidden_inputs, dict) and 'prompt' in hidden_inputs: inputs["prompt"] = {"variable_name": "PROMPT_DATA"} include_prompt_data = True # Create executed variable and generate code executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}' + # Stash return types for this node for later type-aware formatting + try: + self._idx_to_return_types[idx] = getattr(self.node_class_mappings[class_type], 'RETURN_TYPES', tuple()) + except Exception: + self._idx_to_return_types[idx] = tuple() inputs = self.update_inputs(inputs, executed_variables) + # Store current class type for format_arg hotfixes + self._current_processing_class_type = class_type + + # Infer schema generically from function signature default types + expected_string_keys = set() + expected_list_keys = set() + if class_def_params is not None: + # use node INPUT_TYPES metadata for defaults when available + try: + types_meta = self.node_class_mappings[class_type].INPUT_TYPES() + except Exception: + types_meta = {} + required_keys, optional_keys = set(), set() + def _merge(d, into_set): + for k, v in (d or {}).items(): + # v may be (type_or_enum, options) or other + t = v[0] if isinstance(v, (list, tuple)) and v else v + # strings: STRICTLY "STRING" only (avoid misclassifying dynamic inputs like "*") + if t == 'STRING': + expected_string_keys.add(k) + # any/list typed inputs named suggestively + if k.lower() in ('anything', 'list', 'items'): + expected_list_keys.add(k) + into_set.add(k) + _merge(types_meta.get('required'), required_keys) + _merge(types_meta.get('optional'), optional_keys) + # If node declares INPUT_IS_LIST=True or per-key mapping, respect it + try: + ilist = getattr(self.node_class_mappings[class_type], 'INPUT_IS_LIST', None) + except Exception: + ilist = None + if ilist is True: + expected_list_keys.update(required_keys | optional_keys) + elif isinstance(ilist, dict): + expected_list_keys.update({k for k, v in ilist.items() if v}) + if class_type == 'SaveImage': - save_code = self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs).strip() + save_code = self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys).strip() return_code = ['if __name__ != "__main__":', '\treturn dict(' + ', '.join(self.format_arg(key, value) for key, value in inputs.items()) + ')', 'else:', '\t' + save_code] if is_special_function: @@ -264,16 +315,16 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str: code.extend(return_code) ### This should presumably NEVER occur for a valid workflow else: if is_special_function: - special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs)) + special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys)) else: - code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs)) + code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys)) # Generate final code by combining imports and code, and wrap them in a main function final_code = self.assemble_python_code(import_statements, special_functions_code, arg_inputs, code, queue_size, custom_nodes, include_prompt_data) return final_code - def create_function_call_code(self, obj_name: str, func: str, variable_name: str, is_special_function: bool, kwargs) -> str: + def create_function_call_code(self, obj_name: str, func: str, variable_name: str, is_special_function: bool, kwargs, expected_string_keys: set | None = None, expected_list_keys: set | None = None) -> str: """Generate Python code for a function call. Args: @@ -286,11 +337,39 @@ def create_function_call_code(self, obj_name: str, func: str, variable_name: str Returns: str: The generated Python code. """ - args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items()) + # If any kwarg key is not a valid python identifier or is a keyword, + # pass all kwargs via a dict expansion to ensure valid syntax + def _is_safe_identifier(name: Any) -> bool: + return isinstance(name, str) and name.isidentifier() and not keyword.iskeyword(name) + + # Stash format context for this node + self._expected_string_keys = expected_string_keys or set() + self._expected_list_keys = expected_list_keys or set() + # Store current class type for hotfixes + self._current_class_type = getattr(self, '_current_processing_class_type', '') + + use_dict_expansion = any(not _is_safe_identifier(k) for k in kwargs.keys()) + + if use_dict_expansion: + def _format_value(k: str, v: Any) -> str: + # Mirror the logic from format_arg but without attaching the key prefix + if isinstance(v, int) and (k == 'noise_seed' or k == 'seed'): + return 'random.randint(1, 2**64)' + elif isinstance(v, str): + return repr(v) + elif isinstance(v, dict) and 'variable_name' in v: + return v["variable_name"] + return repr(v) if isinstance(v, (list, dict, tuple, set)) else str(v) + + dict_items = ', '.join(f'{repr(k)}: {_format_value(k, v)}' for k, v in kwargs.items()) + args = f'**{{{dict_items}}}' + else: + args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items()) - # Generate the Python code code = f'{variable_name} = {obj_name}.{func}({args})\n' - + # Clear context + self._expected_string_keys = set() + self._expected_list_keys = set() return code def format_arg(self, key: str, value: any) -> str: @@ -309,7 +388,32 @@ def format_arg(self, key: str, value: any) -> str: elif isinstance(value, str): return f'{key}={repr(value)}' elif isinstance(value, dict) and 'variable_name' in value: - return f'{key}={value["variable_name"]}' + # wrap variable references for string-like fields (only when expected for this node) + sa_type = value.get('__sa_type__') if isinstance(value, dict) else None + if hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()) and sa_type in (None, 'STRING'): + return f"{key}=ensure_string({value['variable_name']})" + # ensure list keys are wrapped as lists (ignore __sa_type__ constraint for list keys) + if hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()): + return f"{key}=[{value['variable_name']}]" + # HOTFIX: Force wrap 'anything' parameter for showAnything nodes as list + if key == 'anything' and hasattr(self, '_current_class_type') and 'showAnything' in getattr(self, '_current_class_type', ''): + return f"{key}=[{value['variable_name']}]" + return f"{key}={value['variable_name']}" + # unwrap {"__value__": X} + elif isinstance(value, dict) and '__value__' in value: + inner = value['__value__'] + # if the target arg name is expected string for this node, coerce via ensure_string + if hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()): + return f"{key}=ensure_string({repr(inner)})" + if hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()) and not isinstance(inner, (list, tuple)): + return f"{key}=[{repr(inner)}]" + return f"{key}={repr(inner)}" + # for containers on string params, wrap with ensure_string to avoid list errors downstream + elif hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()) and isinstance(value, (list, tuple, dict)): + return f"{key}=ensure_string({repr(value)})" + # ensure 'anything' input is iterable + elif hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()) and not isinstance(value, (list, tuple)): + return f"{key}=[{repr(value)}]" return f'{key}={value}' def assemble_python_code(self, import_statements: set, special_functions_code: List[str], arg_inputs: List[Tuple[str, str]], code: List[str], queue_size: int, custom_nodes=False, include_prompt_data=True) -> str: @@ -359,7 +463,7 @@ def assemble_python_code(self, import_statements: set, special_functions_code: L # Define static import statements required for the script static_imports = ['import os', 'import random', 'import sys', 'import json', 'import argparse', 'import contextlib', 'from typing import Sequence, Mapping, Any, Union', - 'import torch'] + func_strings + argparse_code + 'import torch', 'has_manager = False'] + func_strings + argparse_code if include_prompt_data: static_imports.append(f'PROMPT_DATA = json.loads({repr(json.dumps(self.prompt))})') # Check if custom nodes should be included @@ -483,9 +587,24 @@ def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict: Returns: Dict: Updated inputs dictionary. """ - for key in inputs.keys(): + for key in list(inputs.keys()): + # decode references [nodeId, index] if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys(): - inputs[key] = {'variable_name': f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})"} + src_idx, src_out = inputs[key][0], inputs[key][1] + ref = f"get_value_at_index({executed_variables[src_idx]}, {src_out})" + # If the referenced RETURN_TYPES indicates IMAGE/MASK/LATENT/MODEL/CLIP, mark as non-string, non-list + try: + out_types = self._idx_to_return_types.get(src_idx, tuple()) + except Exception: + out_types = tuple() + val = {'variable_name': ref} + # Attach a soft hint for downstream formatting to not coerce + if out_types and isinstance(out_types, (list, tuple)) and src_out < len(out_types): + val['__sa_type__'] = out_types[src_out] + inputs[key] = val + # unwrap {"__value__": X} produced by some UIs + elif isinstance(inputs[key], dict) and '__value__' in inputs[key]: + inputs[key] = inputs[key]['__value__'] return inputs diff --git a/comfyui_to_python_utils.py b/comfyui_to_python_utils.py index 6700652..e00f10a 100644 --- a/comfyui_to_python_utils.py +++ b/comfyui_to_python_utils.py @@ -145,6 +145,19 @@ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any: except KeyError: return obj['result'][index] +def ensure_string(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)): + try: + return ", ".join([v if isinstance(v, str) else str(v) for v in value]) + except Exception: + return str(value) + try: + return str(value) + except Exception: + return json.dumps(value, ensure_ascii=False) + def parse_arg(s: Any, default: Any = None) -> Any: """ Parses a JSON string, returning it unchanged if the parsing fails. """ if __name__ == "__main__" or not isinstance(s, str):