Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 132 additions & 13 deletions comfyui_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
import random
import sys
import re
import keyword
from typing import Dict, List, Any, Callable, Tuple, TextIO

import black


from comfyui_to_python_utils import import_custom_nodes, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths,\
get_value_at_index, parse_arg, save_image_wrapper
get_value_at_index, parse_arg, save_image_wrapper, ensure_string

PACKAGED_FUNCTIONS = [
get_value_at_index,
ensure_string,
find_path,
add_comfyui_directory_to_sys_path,
add_extra_model_paths,
Expand Down Expand Up @@ -204,6 +206,9 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:

custom_nodes = False

# Map node id -> return types for downstream type-aware formatting
self._idx_to_return_types: Dict[str, Tuple] = {}

# Loop over each dictionary in the load order list
for idx, data, is_special_function in load_order:
# Generate class definition and inputs from the data
Expand Down Expand Up @@ -246,16 +251,62 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
if class_def_params is not None:
if 'unique_id' in class_def_params:
inputs['unique_id'] = random.randint(1, 2**64)
if 'prompt' in class_def_params:
# only inject PROMPT_DATA when node declares hidden prompt
try:
hidden_inputs = input_types.get('hidden', {}) if isinstance(input_types, dict) else {}
except Exception:
hidden_inputs = {}
if isinstance(hidden_inputs, dict) and 'prompt' in hidden_inputs:
inputs["prompt"] = {"variable_name": "PROMPT_DATA"}
include_prompt_data = True

# Create executed variable and generate code
executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}'
# Stash return types for this node for later type-aware formatting
try:
self._idx_to_return_types[idx] = getattr(self.node_class_mappings[class_type], 'RETURN_TYPES', tuple())
except Exception:
self._idx_to_return_types[idx] = tuple()
inputs = self.update_inputs(inputs, executed_variables)

# Store current class type for format_arg hotfixes
self._current_processing_class_type = class_type

# Infer schema generically from function signature default types
expected_string_keys = set()
expected_list_keys = set()
if class_def_params is not None:
# use node INPUT_TYPES metadata for defaults when available
try:
types_meta = self.node_class_mappings[class_type].INPUT_TYPES()
except Exception:
types_meta = {}
required_keys, optional_keys = set(), set()
def _merge(d, into_set):
for k, v in (d or {}).items():
# v may be (type_or_enum, options) or other
t = v[0] if isinstance(v, (list, tuple)) and v else v
# strings: STRICTLY "STRING" only (avoid misclassifying dynamic inputs like "*")
if t == 'STRING':
expected_string_keys.add(k)
# any/list typed inputs named suggestively
if k.lower() in ('anything', 'list', 'items'):
expected_list_keys.add(k)
into_set.add(k)
_merge(types_meta.get('required'), required_keys)
_merge(types_meta.get('optional'), optional_keys)
# If node declares INPUT_IS_LIST=True or per-key mapping, respect it
try:
ilist = getattr(self.node_class_mappings[class_type], 'INPUT_IS_LIST', None)
except Exception:
ilist = None
if ilist is True:
expected_list_keys.update(required_keys | optional_keys)
elif isinstance(ilist, dict):
expected_list_keys.update({k for k, v in ilist.items() if v})

if class_type == 'SaveImage':
save_code = self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs).strip()
save_code = self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys).strip()
return_code = ['if __name__ != "__main__":', '\treturn dict(' + ', '.join(self.format_arg(key, value) for key, value in inputs.items()) + ')', 'else:', '\t' + save_code]

if is_special_function:
Expand All @@ -264,16 +315,16 @@ def generate_workflow(self, load_order: List, queue_size: int = 1) -> str:
code.extend(return_code) ### This should presumably NEVER occur for a valid workflow
else:
if is_special_function:
special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs))
special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys))
else:
code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs))
code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, inputs, expected_string_keys, expected_list_keys))

# Generate final code by combining imports and code, and wrap them in a main function
final_code = self.assemble_python_code(import_statements, special_functions_code, arg_inputs, code, queue_size, custom_nodes, include_prompt_data)

return final_code

def create_function_call_code(self, obj_name: str, func: str, variable_name: str, is_special_function: bool, kwargs) -> str:
def create_function_call_code(self, obj_name: str, func: str, variable_name: str, is_special_function: bool, kwargs, expected_string_keys: set | None = None, expected_list_keys: set | None = None) -> str:
"""Generate Python code for a function call.

Args:
Expand All @@ -286,11 +337,39 @@ def create_function_call_code(self, obj_name: str, func: str, variable_name: str
Returns:
str: The generated Python code.
"""
args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items())
# If any kwarg key is not a valid python identifier or is a keyword,
# pass all kwargs via a dict expansion to ensure valid syntax
def _is_safe_identifier(name: Any) -> bool:
return isinstance(name, str) and name.isidentifier() and not keyword.iskeyword(name)

# Stash format context for this node
self._expected_string_keys = expected_string_keys or set()
self._expected_list_keys = expected_list_keys or set()
# Store current class type for hotfixes
self._current_class_type = getattr(self, '_current_processing_class_type', '')

use_dict_expansion = any(not _is_safe_identifier(k) for k in kwargs.keys())

if use_dict_expansion:
def _format_value(k: str, v: Any) -> str:
# Mirror the logic from format_arg but without attaching the key prefix
if isinstance(v, int) and (k == 'noise_seed' or k == 'seed'):
return 'random.randint(1, 2**64)'
elif isinstance(v, str):
return repr(v)
elif isinstance(v, dict) and 'variable_name' in v:
return v["variable_name"]
return repr(v) if isinstance(v, (list, dict, tuple, set)) else str(v)

dict_items = ', '.join(f'{repr(k)}: {_format_value(k, v)}' for k, v in kwargs.items())
args = f'**{{{dict_items}}}'
else:
args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items())

# Generate the Python code
code = f'{variable_name} = {obj_name}.{func}({args})\n'

# Clear context
self._expected_string_keys = set()
self._expected_list_keys = set()
return code

def format_arg(self, key: str, value: any) -> str:
Expand All @@ -309,7 +388,32 @@ def format_arg(self, key: str, value: any) -> str:
elif isinstance(value, str):
return f'{key}={repr(value)}'
elif isinstance(value, dict) and 'variable_name' in value:
return f'{key}={value["variable_name"]}'
# wrap variable references for string-like fields (only when expected for this node)
sa_type = value.get('__sa_type__') if isinstance(value, dict) else None
if hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()) and sa_type in (None, 'STRING'):
return f"{key}=ensure_string({value['variable_name']})"
# ensure list keys are wrapped as lists (ignore __sa_type__ constraint for list keys)
if hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()):
return f"{key}=[{value['variable_name']}]"
# HOTFIX: Force wrap 'anything' parameter for showAnything nodes as list
if key == 'anything' and hasattr(self, '_current_class_type') and 'showAnything' in getattr(self, '_current_class_type', ''):
return f"{key}=[{value['variable_name']}]"
return f"{key}={value['variable_name']}"
# unwrap {"__value__": X}
elif isinstance(value, dict) and '__value__' in value:
inner = value['__value__']
# if the target arg name is expected string for this node, coerce via ensure_string
if hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()):
return f"{key}=ensure_string({repr(inner)})"
if hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()) and not isinstance(inner, (list, tuple)):
return f"{key}=[{repr(inner)}]"
return f"{key}={repr(inner)}"
# for containers on string params, wrap with ensure_string to avoid list errors downstream
elif hasattr(self, '_expected_string_keys') and key in getattr(self, '_expected_string_keys', set()) and isinstance(value, (list, tuple, dict)):
return f"{key}=ensure_string({repr(value)})"
# ensure 'anything' input is iterable
elif hasattr(self, '_expected_list_keys') and key in getattr(self, '_expected_list_keys', set()) and not isinstance(value, (list, tuple)):
return f"{key}=[{repr(value)}]"
return f'{key}={value}'

def assemble_python_code(self, import_statements: set, special_functions_code: List[str], arg_inputs: List[Tuple[str, str]], code: List[str], queue_size: int, custom_nodes=False, include_prompt_data=True) -> str:
Expand Down Expand Up @@ -359,7 +463,7 @@ def assemble_python_code(self, import_statements: set, special_functions_code: L

# Define static import statements required for the script
static_imports = ['import os', 'import random', 'import sys', 'import json', 'import argparse', 'import contextlib', 'from typing import Sequence, Mapping, Any, Union',
'import torch'] + func_strings + argparse_code
'import torch', 'has_manager = False'] + func_strings + argparse_code
if include_prompt_data:
static_imports.append(f'PROMPT_DATA = json.loads({repr(json.dumps(self.prompt))})')
# Check if custom nodes should be included
Expand Down Expand Up @@ -483,9 +587,24 @@ def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict:
Returns:
Dict: Updated inputs dictionary.
"""
for key in inputs.keys():
for key in list(inputs.keys()):
# decode references [nodeId, index]
if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys():
inputs[key] = {'variable_name': f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})"}
src_idx, src_out = inputs[key][0], inputs[key][1]
ref = f"get_value_at_index({executed_variables[src_idx]}, {src_out})"
# If the referenced RETURN_TYPES indicates IMAGE/MASK/LATENT/MODEL/CLIP, mark as non-string, non-list
try:
out_types = self._idx_to_return_types.get(src_idx, tuple())
except Exception:
out_types = tuple()
val = {'variable_name': ref}
# Attach a soft hint for downstream formatting to not coerce
if out_types and isinstance(out_types, (list, tuple)) and src_out < len(out_types):
val['__sa_type__'] = out_types[src_out]
inputs[key] = val
# unwrap {"__value__": X} produced by some UIs
elif isinstance(inputs[key], dict) and '__value__' in inputs[key]:
inputs[key] = inputs[key]['__value__']
return inputs


Expand Down
13 changes: 13 additions & 0 deletions comfyui_to_python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,19 @@ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
except KeyError:
return obj['result'][index]

def ensure_string(value: Any) -> str:
    """Coerce an arbitrary value to a plain string.

    Used by generated workflow code to guard STRING-typed node inputs
    that may receive non-string upstream outputs.

    Args:
        value: Any value — a string, a sequence, or any other object.

    Returns:
        str: ``value`` unchanged if it is already a string; a
        comma-separated join of the str-coerced elements if it is a
        list or tuple; otherwise ``str(value)``.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        # Coerce each element individually so mixed-type sequences
        # (e.g. [1, "a", 2.5]) render cleanly as "1, a, 2.5".
        return ", ".join(str(v) for v in value)
    # str() is total for ordinary objects; the previous json.dumps
    # fallback could never succeed where str() failed and relied on a
    # json import that is not guaranteed in scope here.
    return str(value)

def parse_arg(s: Any, default: Any = None) -> Any:
""" Parses a JSON string, returning it unchanged if the parsing fails. """
if __name__ == "__main__" or not isinstance(s, str):
Expand Down