Conversation

@rxu183 (Contributor) commented Feb 3, 2026

Changes:

  • Created a 1.3B WAN config yaml (sketched below). Removed the LoRA configuration present in the 14B setup because I couldn't find a relevant LoRA for the 1.3B model.
  • Parameterized num_layers in wan_utils (defaulting to 40) instead of hardcoding the value, so that varying model architectures are supported while keeping backwards compatibility; see the sketch after this list. In particular, the 1.3B model has 30 layers rather than the 14B's 40, and the hardcoded value was causing a shape mismatch error.
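
For reference, a minimal sketch of the kind of settings the new 1.3B yaml carries; the field names and values mirror the config dump in the logs further down, but the committed file is the source of truth:

```yaml
# Hypothetical excerpt of the new 1.3B config (values taken from the run log below).
model_name: wan2.1
model_type: T2V
pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
tokenizer_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
wan_transformer_pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
enable_lora: False   # LoRA block from the 14B config removed
height: 480
width: 832
num_frames: 81
```

And a minimal sketch of the num_layers parameterization, assuming a hypothetical helper that splits a flat PyTorch state dict into per-layer groups (the actual function in wan_utils has a different name and does more work):

```python
# Sketch only: illustrates making the transformer depth a parameter with a
# backwards-compatible default instead of a hardcoded constant.
def split_blocks(pt_state_dict, num_layers: int = 40):
  """Split a flat per-block state dict into one dict per transformer layer.

  Defaulting to 40 keeps existing 14B call sites unchanged; the 1.3B path
  passes num_layers=30, which avoids the shape mismatch seen with the
  hardcoded value.
  """
  per_layer = []
  for layer_idx in range(num_layers):  # previously a fixed range over 40 layers
    prefix = f"blocks.{layer_idx}."
    per_layer.append(
        {k[len(prefix):]: v for k, v in pt_state_dict.items() if k.startswith(prefix)}
    )
  return per_layer
```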

Testing:

  • Tested on a Google Colab TPU v5e-1 instance
Successful execution logs:
/content/maxdiffusion
/usr/local/lib/python3.12/dist-packages/jax/_src/cloud_tpu_init.py:93: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled")
  warnings.warn(
WARNING:2026-02-03 22:20:38,221:jax._src.distributed:144: JAX detected proxy variable(s) in the environment as distributed setup: MODEL_PROXY_HOST COLAB_LANGUAGE_SERVER_PROXY_ROOT_URL COLAB_LANGUAGE_SERVER_PROXY_REQUEST_TIMEOUT COLAB_KERNEL_MANAGER_PROXY_HOST COLAB_LANGUAGE_SERVER_PROXY_LSP_DIRS COLAB_KERNEL_MANAGER_PROXY_PORT COLAB_LANGUAGE_SERVER_PROXY. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
W0203 22:20:38.221627 136016303665152 distributed.py:144] JAX detected proxy variable(s) in the environment as distributed setup: MODEL_PROXY_HOST COLAB_LANGUAGE_SERVER_PROXY_ROOT_URL COLAB_LANGUAGE_SERVER_PROXY_REQUEST_TIMEOUT COLAB_KERNEL_MANAGER_PROXY_HOST COLAB_LANGUAGE_SERVER_PROXY_LSP_DIRS COLAB_KERNEL_MANAGER_PROXY_PORT COLAB_LANGUAGE_SERVER_PROXY. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
INFO:2026-02-03 22:20:38,221:jax._src.distributed:149: Starting JAX distributed service on [::]:8482
I0203 22:20:38.221867 136016303665152 distributed.py:149] Starting JAX distributed service on [::]:8482
INFO:2026-02-03 22:20:38,223:jax._src.distributed:166: Connecting to JAX distributed service on localhost:8482
I0203 22:20:38.223598 136016303665152 distributed.py:166] Connecting to JAX distributed service on localhost:8482
Adding sequence sharding to q and kv if not already present because flash=='ring' or True is set.
Initial logical axis rules: [('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context')]
Adding sequence parallel attention axis rule ['activation_self_attn_heads', None]
Adding sequence parallel attention axis rule ['activation_self_attn_q_length', 'context']
Adding sequence parallel attention axis rule ['activation_self_attn_kv_length', None]
Adding sequence parallel attention axis rule ['activation_cross_attn_heads', None]
Adding sequence parallel attention axis rule ['activation_cross_attn_q_length', 'context']
Adding sequence parallel attention axis rule ['activation_cross_attn_kv_length', None]
Final logical axis rules: (['activation_self_attn_heads', None], ['activation_self_attn_q_length', 'context'], ['activation_self_attn_kv_length', None], ['activation_cross_attn_heads', None], ['activation_cross_attn_q_length', 'context'], ['activation_cross_attn_kv_length', None], ('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context'), ('activation_kv_length', 'context'))
Config param act_quantization_calibration_method: absmax
Config param activations_dtype: bfloat16
Config param adam_b1: 0.9
Config param adam_b2: 0.999
Config param adam_eps: 1e-08
Config param adam_weight_decay: 0
Config param allow_split_physical_axes: False
Config param attention: flash
Config param attention_sharding_uniform: True
Config param base_output_directory: 
Config param bwd_quantization_calibration_method: absmax
Config param cache_latents_text_encoder_outputs: True
Config param caption_column: text
Config param center_crop: False
Config param checkpoint_dir: "/content/output"/"wan1.3_dynamic"/checkpoints/
Config param checkpoint_every: -1
Config param compile_topology_num_slices: -1
Config param controlnet_conditioning_scale: 0.5
Config param controlnet_from_pt: True
Config param controlnet_image: https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png
Config param controlnet_model_name_or_path: diffusers/controlnet-canny-sdxl-1.0
Config param data_sharding: (('data', 'fsdp', 'context', 'tensor'),)
Config param dataset_config_name: 
Config param dataset_name: diffusers/pokemon-gpt4-captions
Config param dataset_save_location: 
Config param dataset_type: tfrecord
Config param dcn_context_parallelism: -1
Config param dcn_data_parallelism: 1
Config param dcn_fsdp_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param diffusion_scheduler_config: {'_class_name': 'FlaxEulerDiscreteScheduler', 'prediction_type': 'epsilon', 'rescale_zero_terminal_snr': False, 'timestep_spacing': 'trailing'}
Config param do_classifier_free_guidance: True
Config param dropout: 0.1
Config param enable_data_shuffling: True
Config param enable_eval_timesteps: False
Config param enable_generate_video_for_eval: False
Config param enable_jax_named_scopes: False
Config param enable_lora: False
Config param enable_mllog: False
Config param enable_profiler: False
Config param enable_single_replica_ckpt_restoring: False
Config param enable_ssim: False
Config param eval_data_dir: 
Config param eval_every: -1
Config param eval_max_number_of_samples_in_bucket: 60
Config param flash_block_sizes: {'block_q': 512, 'block_kv_compute': 512, 'block_kv': 512, 'block_q_dkv': 512, 'block_kv_dkv': 512, 'block_kv_dkv_compute': 512, 'block_q_dq': 512, 'block_kv_dq': 512, 'use_fused_bwd_kernel': False}
Config param flash_min_seq_length: 0
Config param flow_shift: 3.0
Config param fps: 16
Config param from_pt: True
Config param gcs_metrics: False
Config param global_batch_size: 0
Config param global_batch_size_to_load: 1
Config param global_batch_size_to_train_on: 1
Config param guidance_rescale: 0.0
Config param guidance_scale: 5.0
Config param hardware: tpu
Config param height: 480
Config param hf_access_token: None
Config param hf_data_dir: 
Config param hf_train_files: None
Config param ici_context_parallelism: 1
Config param ici_data_parallelism: -1
Config param ici_fsdp_parallelism: 1
Config param ici_tensor_parallelism: 1
Config param image_column: image
Config param jax_cache_dir: 
Config param jit_initializers: True
Config param learning_rate: 1e-05
Config param learning_rate_schedule_steps: 1500
Config param lightning_ckpt: 
Config param lightning_from_pt: True
Config param lightning_repo: 
Config param load_tfrecord_cached: True
Config param log_period: 100
Config param logical_axis_rules: (['activation_self_attn_heads', None], ['activation_self_attn_q_length', 'context'], ['activation_self_attn_kv_length', None], ['activation_cross_attn_heads', None], ['activation_cross_attn_q_length', 'context'], ['activation_cross_attn_kv_length', None], ('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context'), ('activation_kv_length', 'context'), ('layers_per_stage', None))
Config param lora_config: {'rank': [64], 'lora_model_name_or_path': [''], 'weight_name': [''], 'adapter_name': [''], 'scale': [1.0], 'from_pt': []}
Config param mask_padding_tokens: True
Config param max_grad_norm: 1.0
Config param max_train_samples: -1
Config param max_train_steps: 1500
Config param mesh_axes: ['data', 'fsdp', 'context', 'tensor']
Config param metrics_dir: "/content/output"/"wan1.3_dynamic"/metrics/
Config param metrics_file: 
Config param model_name: wan2.1
Config param model_type: T2V
Config param names_which_can_be_offloaded: []
Config param names_which_can_be_saved: []
Config param negative_prompt: Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
Config param no_records_per_shard: 0
Config param norm_num_groups: 32
Config param num_eval_samples: 420
Config param num_frames: 81
Config param num_inference_steps: 40
Config param num_slices: 1
Config param num_train_epochs: 1
Config param output_dir: "/content/output"
Config param per_device_batch_size: 1.0
Config param precision: DEFAULT
Config param pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
Config param profiler_steps: 10
Config param prompt: "A cinematic drone shot of a futuristic city at sunset, highly detailed, 8k"
Config param prompt_2: A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.
Config param quantization: 
Config param quantization_local_shard_count: 1
Config param qwix_module_path: .*
Config param random_flip: False
Config param remat_policy: NONE
Config param replicate_vae: False
Config param resolution: 1024
Config param reuse_example_batch: False
Config param revision: 
Config param run_name: "wan1.3_dynamic"
Config param save_config_to_gcs: False
Config param save_final_checkpoint: False
Config param save_optimizer: False
Config param scale_lr: False
Config param scan_layers: True
Config param seed: 0
Config param skip_first_n_steps_for_profiler: 5
Config param skip_jax_distributed_system: False
Config param snr_gamma: -1.0
Config param split_head_dim: True
Config param tensorboard_dir: "/content/output"/"wan1.3_dynamic"/tensorboard/
Config param text_encoder_learning_rate: 4.25e-06
Config param tfrecords_dir: 
Config param timestep_bias: {'strategy': 'none', 'multiplier': 1.0, 'begin': 0, 'end': 1000, 'portion': 0.25}
Config param timesteps_list: [125, 250, 375, 500, 625, 750, 875]
Config param timing_metrics_file: 
Config param tokenize_captions_num_proc: 4
Config param tokenizer_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
Config param total_train_batch_size: 1.0
Config param train_data_dir: 
Config param train_split: train
Config param train_text_encoder: False
Config param transform_images_num_proc: 4
Config param unet_checkpoint: 
Config param use_qwix_quantization: False
Config param wan_transformer_pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
Config param warmup_steps_fraction: 0.1
Config param weight_quantization_calibration_method: absmax
Config param weights_dtype: bfloat16
Config param width: 832
Config param write_metrics: True
Config param write_timing_metrics: True
TensorBoard logs will be written to: "/content/output"/"wan1.3_dynamic"/tensorboard/
Git Commit Hash: f23746bfc25d4d6159403d17c66856ba82cf64b7
Creating checkpoing manager...
checkpoint dir: "/content/output"/"wan1.3_dynamic"/checkpoints/
item_names: ('low_noise_transformer_state', 'high_noise_transformer_state', 'wan_state', 'wan_config')
I0203 22:20:39.935555 136016303665152 checkpoint_manager.py:709] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=('low_noise_transformer_state', 'high_noise_transformer_state', 'wan_state', 'wan_config'), item_handlers=None, handler_registry=None
I0203 22:20:39.935851 136016303665152 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler `<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bb476920350>` for item "metrics" and save args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>` and restore args `<class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>` to `_handler_registry`.
I0203 22:20:39.935948 136016303665152 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bb476920350>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7bb476920350>}).
I0203 22:20:39.936277 136016303665152 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.32
I0203 22:20:39.936344 136016303665152 async_checkpointer.py:177] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>.<lambda> at 0x7bb3c1776840> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0203 22:20:39.936712 136016303665152 checkpoint_manager.py:1818] Found 0 checkpoint steps in "/content/output"/"wan1.3_dynamic"/checkpoints
I0203 22:20:39.936876 136016303665152 checkpoint_manager.py:929] [process=0][thread=MainThread] CheckpointManager created,  primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False), root_directory="/content/output"/"wan1.3_dynamic"/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7bb4768fffe0>
Checkpoint manager created!
Latest WAN checkpoint step: None
No WAN checkpoint found.
No checkpoint found, loading default pipeline.
Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)] (num_devices: 1)
Decided on mesh: [[[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]]]
/content/maxdiffusion/src/maxdiffusion/configuration_utils.py:262: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
/usr/local/lib/python3.12/dist-packages/huggingface_hub/file_download.py:942: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Load and port Wan-AI/Wan2.1-T2V-1.3B-Diffusers VAE on TFRT_CPU_0
Downloading shards: 100% 5/5 [00:00<00:00, 14324.81it/s]
Loading checkpoint shards: 100% 5/5 [00:01<00:00,  3.18it/s]
The config attributes {'rescale_betas_zero_snr': False} were passed to FlaxUniPCMultistepScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
Load and port Wan-AI/Wan2.1-T2V-1.3B-Diffusers transformer on TFRT_CPU_0
Load and port Wan-AI/Wan2.1-T2V-1.3B-Diffusers transformer on TFRT_CPU_0
is_norm_path:  (DictKey(key='blocks'), DictKey(key='adaln_scale_shift_table'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='attn1'), DictKey(key='norm_k'), DictKey(key='scale'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='attn1'), DictKey(key='norm_q'), DictKey(key='scale'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='attn2'), DictKey(key='norm_k'), DictKey(key='scale'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='attn2'), DictKey(key='norm_q'), DictKey(key='scale'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='norm2'), DictKey(key='layer_norm'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='blocks'), DictKey(key='norm2'), DictKey(key='layer_norm'), DictKey(key='scale'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='text_embedder'), DictKey(key='linear_1'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='text_embedder'), DictKey(key='linear_1'), DictKey(key='kernel'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='text_embedder'), DictKey(key='linear_2'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='text_embedder'), DictKey(key='linear_2'), DictKey(key='kernel'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_embedder'), DictKey(key='linear_1'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_embedder'), DictKey(key='linear_1'), DictKey(key='kernel'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_embedder'), DictKey(key='linear_2'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_embedder'), DictKey(key='linear_2'), DictKey(key='kernel'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_proj'), DictKey(key='bias'))
is_norm_path:  (DictKey(key='condition_embedder'), DictKey(key='time_proj'), DictKey(key='kernel'))
is_norm_path:  (DictKey(key='scale_shift_table'),)
Num steps: 40, height: 480, width: 832, frames: 81
/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/array_methods.py:125: UserWarning: Explicitly requested dtype int64 requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
===================== Model details =======================
model name: wan2.1
model path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model type: T2V
hardware: tpu
number of devices: 1
per_device_batch_size: 1.0
============================================================
compile_time: 409.6654962049997
/usr/lib/python3.12/subprocess.py:1885: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _fork_exec(
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
generation_time: 349.6298596889992
generation time per video: 349.6298596889992

@rxu183 requested a review from entrpn as a code owner February 3, 2026 23:18
@entrpn merged commit ced76d0 into AI-Hypercomputer:main Feb 4, 2026
4 checks passed
@entrpn (Collaborator) commented Feb 4, 2026

Merged. Thank you.

@rxu183 deleted the richard/1_3b branch February 4, 2026 05:25