From b409d3af51ab460be58f4ca80ca752df35f629fb Mon Sep 17 00:00:00 2001 From: Richard <99752583+rxu183@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:42:08 -0600 Subject: [PATCH] Add WAN 1.3B Config and parameterize num_layers --- src/maxdiffusion/configs/base_wan_1_3b.yml | 332 +++++++++++++++++++++ src/maxdiffusion/models/wan/wan_utils.py | 16 +- 2 files changed, 343 insertions(+), 5 deletions(-) create mode 100644 src/maxdiffusion/configs/base_wan_1_3b.yml diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml new file mode 100644 index 00000000..ffd2864a --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -0,0 +1,332 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers' +model_name: wan2.1 +model_type: 'T2V' + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 0 + +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True +dropout: 0.1 + +flash_block_sizes: { + "block_q" : 512, + "block_kv_compute" : 512, + "block_kv" : 512, + "block_q_dkv" : 512, + "block_kv_dkv" : 512, + "block_kv_dkv_compute" : 512, + "block_q_dq" : 512, + "block_kv_dq" : 512, + "use_fused_bwd_kernel": False, +} +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed', ['context', 'fsdp']], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data', 'context', 'fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'context'], + ] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: 1 +ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +guidance_scale: 5.0 +flow_shift: 3.0 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 16 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +enable_lora: False +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + rank: [64], + lora_model_name_or_path: [""], + weight_name: [""], + adapter_name: [""], + scale: [1.0], + from_pt: [] +} + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +weight_quantization_calibration_method: "absmax" +act_quantization_calibration_method: "absmax" +bwd_quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index b12d1907..aa731923 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -72,7 +72,7 @@ def rename_for_custom_trasformer(key): return renamed_pt_key -def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers): +def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=40): if scan_layers: if "blocks" in pt_tuple_key: new_key = ("blocks",) + pt_tuple_key[2:] @@ -89,7 +89,7 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) return flax_key, flax_tensor @@ -127,7 +127,9 @@ def load_fusionx_transformer( pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers) + flax_key, flax_tensor = get_key_and_value( + pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers + ) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -167,7 +169,9 @@ def load_causvid_transformer( renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers) + flax_key, flax_tensor = get_key_and_value( + pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers + ) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -284,7 +288,9 @@ def load_base_wan_transformer( renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers) + flax_key, flax_tensor = get_key_and_value( + pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers + ) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict)