diff --git a/diffsynth/core/gradient/gradient_checkpoint.py b/diffsynth/core/gradient/gradient_checkpoint.py
index b356415a0..d252573a2 100644
--- a/diffsynth/core/gradient/gradient_checkpoint.py
+++ b/diffsynth/core/gradient/gradient_checkpoint.py
@@ -21,6 +21,7 @@ def gradient_checkpoint_forward(
                 *args, **kwargs,
                 use_reentrant=False,
+                determinism_check="none"
             )
     elif use_gradient_checkpointing:
         model_output = torch.utils.checkpoint.checkpoint(
@@ -28,6 +29,7 @@ def gradient_checkpoint_forward(
             *args, **kwargs,
             use_reentrant=False,
+            determinism_check="none"
         )
     else:
         model_output = model(*args, **kwargs)
diff --git a/diffsynth/core/loader/model.py b/diffsynth/core/loader/model.py
index 56fa7d362..5d9b05275 100644
--- a/diffsynth/core/loader/model.py
+++ b/diffsynth/core/loader/model.py
@@ -3,21 +3,24 @@ from ..vram.layers import enable_vram_management
 from .file import load_state_dict
 import torch
+from contextlib import contextmanager
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils import ContextManagers
 
 
-def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
+def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None,
+               use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
     config = {} if config is None else config
-    # Why do we use `skip_model_initialization`?
-    # It skips the random initialization of model parameters,
-    # thereby speeding up model loading and avoiding excessive memory usage.
-    with skip_model_initialization():
+    with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
         model = model_class(**config)
     # What is `module_map`?
     # This is a module mapping table for VRAM management.
     if module_map is not None:
-        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
+        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"],
+                   vram_config["computation_device"]]
         device = [d for d in devices if d != "disk"][0]
-        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
+        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"],
+                  vram_config["computation_dtype"]]
         dtype = [d for d in dtypes if d != "disk"][0]
         if vram_config["offload_device"] != "disk":
             state_dict = DiskMap(path, device, torch_dtype=dtype)
@@ -26,10 +29,12 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
             else:
                 state_dict = {i: state_dict[i] for i in state_dict}
             model.load_state_dict(state_dict, assign=True)
-            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
+            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None,
+                                           vram_limit=vram_limit)
         else:
             disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
-            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
+            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map,
+                                           vram_limit=vram_limit)
     else:
         # Why do we use `DiskMap`?
         # Sometimes a model file contains multiple models,
@@ -46,7 +51,11 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
             state_dict = state_dict_converter(state_dict)
         else:
             state_dict = {i: state_dict[i] for i in state_dict}
-        model.load_state_dict(state_dict, assign=True)
+        if is_deepspeed_zero3_enabled():
+            from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
+            _load_state_dict_into_zero3_model(model, state_dict)
+        else:
+            model.load_state_dict(state_dict, assign=True)
         # Why do we call `to()`?
         # Because some models override the behavior of `to()`,
         # especially those from libraries like Transformers.
@@ -56,7 +65,8 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
     return model
 
 
-def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
+def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu",
+                                 state_dict_converter=None, module_map=None):
     if isinstance(path, str):
         path = [path]
     config = {} if config is None else config
@@ -77,3 +87,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
     }
     enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
     return model
+
+
+def get_init_context(torch_dtype, device):
+    if is_deepspeed_zero3_enabled():
+        from transformers.modeling_utils import set_zero3_state
+        import deepspeed
+        # Why do we use `deepspeed.zero.Init`?
+        # It partitions the model weights on the CPU side during initialization,
+        # so each rank only loads its own shard onto the computation device.
+        init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
+    else:
+        # Why do we use `skip_model_initialization`?
+        # It skips the random initialization of model parameters,
+        # thereby speeding up model loading and avoiding excessive memory usage.
+        init_contexts = [skip_model_initialization()]
+
+    return init_contexts
diff --git a/diffsynth/diffusion/logger.py b/diffsynth/diffusion/logger.py
index 6d2792f1b..ab6bdb9be 100644
--- a/diffsynth/diffusion/logger.py
+++ b/diffsynth/diffusion/logger.py
@@ -18,8 +18,8 @@ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_ste
 
     def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
         accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
         if accelerator.is_main_process:
-            state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save
 
     def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
         accelerator.wait_for_everyone()
+        state_dict = accelerator.get_state_dict(model)
         if accelerator.is_main_process:
-            state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py
index f6e2263ae..6e26035e8 100644
--- a/diffsynth/diffusion/runner.py
+++ b/diffsynth/diffusion/runner.py
@@ -27,7 +27,7 @@ def launch_training_task(
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
-
+    model.to(device=accelerator.device)
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
 
     for epoch_id in range(num_epochs):
@@ -59,6 +59,7 @@ def launch_data_process_task(
     num_workers = args.dataset_num_workers
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
 
+    model.to(device=accelerator.device)
     model, dataloader = accelerator.prepare(model, dataloader)
 
     for data_id, data in enumerate(tqdm(dataloader)):
diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py
index 3ace70d8b..8873affb2 100644
--- a/diffsynth/models/wan_video_animate_adapter.py
+++ b/diffsynth/models/wan_video_animate_adapter.py
@@ -607,7 +607,7 @@ def __init__(self, size, style_dim=512, motion_dim=20):
 
     def get_motion(self, img):
         #motion_feat = self.enc.enc_motion(img)
-        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
+        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True, determinism_check="none")
         motion = self.dec.direction(motion_feat)
         return motion
diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 738622302..c11a1bc02 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -5,6 +5,7 @@ from typing import Tuple, Optional
 from einops import rearrange
 from .wan_video_camera_controller import SimpleAdapter
+from ..core.gradient import gradient_checkpoint_forward
 
 try:
     import flash_attn_interface
@@ -379,27 +380,15 @@ def forward(self,
             self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
             self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
-
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
 
         for block in self.blocks:
-            if self.training and use_gradient_checkpointing:
-                if use_gradient_checkpointing_offload:
-                    with torch.autograd.graph.save_on_cpu():
-                        x = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block),
-                            x, context, t_mod, freqs,
-                            use_reentrant=False,
-                        )
-                else:
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
+            if self.training:
+                x = gradient_checkpoint_forward(
+                    block,
+                    use_gradient_checkpointing,
+                    use_gradient_checkpointing_offload,
+                    x, context, t_mod, freqs
+                )
             else:
                 x = block(x, context, t_mod, freqs)
diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py
index 8fbed8c05..f4d1abe4c 100644
--- a/diffsynth/models/wan_video_dit_s2v.py
+++ b/diffsynth/models/wan_video_dit_s2v.py
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from typing import Tuple
 from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
+from ..core.gradient import gradient_checkpoint_forward
 
 
 def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ def forward(
         t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
         t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
 
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
-
         for block_id, block in enumerate(self.blocks):
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x,
-                        context,
-                        t_mod,
-                        seq_len_x,
-                        pre_compute_freqs[0],
-                        use_reentrant=False,
-                    )
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                        x,
-                        use_reentrant=False,
-                    )
-            elif use_gradient_checkpointing:
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    x,
-                    context,
-                    t_mod,
-                    seq_len_x,
-                    pre_compute_freqs[0],
-                    use_reentrant=False,
-                )
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                    x,
-                    use_reentrant=False,
-                )
-            else:
-                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
-                x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
+            x = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x, context, t_mod, seq_len_x, pre_compute_freqs[0]
+            )
+            x = gradient_checkpoint_forward(
+                lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x
+            )
 
         x = x[:, :seq_len_x]
         x = self.head(x, t[:-1])
diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py
index f3367f788..0e13183c2 100644
--- a/diffsynth/models/wan_video_vace.py
+++ b/diffsynth/models/wan_video_vace.py
@@ -1,6 +1,6 @@
 import torch
 from .wan_video_dit import DiTBlock
-
+from ..core.gradient import gradient_checkpoint_forward
 
 class VaceWanAttentionBlock(DiTBlock):
     def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ def forward(
             dim=1) for u in c
         ])
 
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
-
         for block in self.vace_blocks:
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    c = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        c, x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
-            elif use_gradient_checkpointing:
-                c = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    c, x, context, t_mod, freqs,
-                    use_reentrant=False,
-                )
-            else:
-                c = block(c, x, context, t_mod, freqs)
+            c = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                c, x, context, t_mod, freqs
+            )
+
         hints = torch.unbind(c)[:-1]
         return hints
diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py
index d24e29d93..e43e6a8cd 100644
--- a/diffsynth/models/wan_video_vae.py
+++ b/diffsynth/models/wan_video_vae.py
@@ -171,7 +171,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
                     torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
-        return x
+        return x, feat_cache, feat_idx
 
     def init_weight(self, conv):
         conv_weight = conv.weight
@@ -298,7 +298,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x + h
+        return x + h, feat_cache, feat_idx
 
 
 class AttentionBlock(nn.Module):
@@ -471,7 +471,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         for module in self.downsamples:
             x = module(x, feat_cache, feat_idx)
 
-        return x + self.avg_shortcut(x_copy)
+        return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
 
 
 class Up_ResidualBlock(nn.Module):
@@ -511,7 +511,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
             x_shortcut = self.avg_shortcut(x, first_chunk)
             return x_main + x_shortcut
         else:
-            return x_main
+            return x_main, feat_cache, feat_idx
 
 
 class Encoder3d(nn.Module):
@@ -586,14 +586,14 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.downsamples:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
         ## middle
         for layer in self.middle:
             if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
@@ -614,7 +614,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x
+        return x, feat_cache, feat_idx
 
 
 class Encoder3d_38(nn.Module):
@@ -698,14 +698,14 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## downsamples
         for layer in self.downsamples:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
         ## middle
         for layer in self.middle:
             if isinstance(layer, ResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
@@ -730,7 +730,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
             else:
                 x = layer(x)
 
-        return x
+        return x, feat_cache, feat_idx
 
 
 class Decoder3d(nn.Module):
@@ -807,14 +807,14 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## middle
         for layer in self.middle:
             if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
         ## upsamples
         for layer in self.upsamples:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
@@ -835,7 +835,7 @@
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x
+        return x, feat_cache, feat_idx
 
 
@@ -906,14 +906,14 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
 
         for layer in self.middle:
             if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
         ## upsamples
         for layer in self.upsamples:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx, first_chunk)
+                x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx, first_chunk)
             else:
                 x = layer(x)
 
@@ -937,7 +937,7 @@
                 feat_idx[0] += 1
             else:
                 x = layer(x)
-        return x
+        return x, feat_cache, feat_idx
 
 
 def count_conv3d(model):
@@ -990,11 +990,11 @@ def encode(self, x, scale):
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
-                out = self.encoder(x[:, :, :1, :, :],
+                out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
                                    feat_cache=self._enc_feat_map,
                                    feat_idx=self._enc_conv_idx)
             else:
-                out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+                out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                                     feat_cache=self._enc_feat_map,
                                     feat_idx=self._enc_conv_idx)
                 out = torch.cat([out, out_], 2)
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index d5dc35bd4..c68dcb9ae 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -348,7 +348,7 @@ def get_qwen3_prompt_embeds(
         attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
 
         # Forward pass through the model
-        with torch.inference_mode():
+        with torch.no_grad():
             output = text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index edd6dffc3..f5687010d 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -1321,11 +1321,6 @@ def model_fn_wan_video(
     if tea_cache_update:
         x = tea_cache.update(x)
     else:
-        def create_custom_forward(module):
-            def custom_forward(*inputs):
-                return module(*inputs)
-            return custom_forward
-
         def create_custom_forward_vap(block, vap):
             def custom_forward(*inputs):
                 return vap(block, *inputs)
@@ -1340,31 +1335,25 @@ def custom_forward(*inputs):
                             create_custom_forward_vap(block, vap),
                             x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                             use_reentrant=False,
+                            determinism_check="none"
                         )
                 elif use_gradient_checkpointing:
                     x, x_vap = torch.utils.checkpoint.checkpoint(
                         create_custom_forward_vap(block, vap),
                         x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
                         use_reentrant=False,
+                        determinism_check="none"
                     )
                 else:
                     x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
             else:
-                if use_gradient_checkpointing_offload:
-                    with torch.autograd.graph.save_on_cpu():
-                        x = torch.utils.checkpoint.checkpoint(
-                            create_custom_forward(block),
-                            x, context, t_mod, freqs,
-                            use_reentrant=False,
-                        )
-                elif use_gradient_checkpointing:
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
-                else:
-                    x = block(x, context, t_mod, freqs)
+                x = gradient_checkpoint_forward(
+                    block,
+                    use_gradient_checkpointing,
+                    use_gradient_checkpointing_offload,
+                    x, context, t_mod, freqs
+                )
+
         # VACE
         if vace_context is not None and block_id in vace.vace_layers_mapping:
@@ -1487,32 +1476,18 @@ def custom_forward(*inputs):
             return custom_forward
 
         for block_id, block in enumerate(dit.blocks):
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, seq_len_x, pre_compute_freqs[0],
-                        use_reentrant=False,
-                    )
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                        x,
-                        use_reentrant=False,
-                    )
-            elif use_gradient_checkpointing:
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    x, context, t_mod, seq_len_x, pre_compute_freqs[0],
-                    use_reentrant=False,
-                )
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
-                    x,
-                    use_reentrant=False,
+            x = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x, context, t_mod, seq_len_x, pre_compute_freqs[0]
                 )
-            else:
-                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
-                x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
+            x = gradient_checkpoint_forward(
+                lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x
+            )
 
         if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
             x = get_sp_group().all_gather(x, dim=1)
diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index 21dc3b33c..cea55e479 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -6,6 +6,7 @@ get_sp_group)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from ...core.device import parse_nccl_backend, parse_device_type
+from ...core.gradient import gradient_checkpoint_forward
 
 
 def initialize_usp(device_type):
@@ -81,11 +82,6 @@ def usp_dit_forward(self,
         self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
     ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
-
-    def create_custom_forward(module):
-        def custom_forward(*inputs):
-            return module(*inputs)
-        return custom_forward
 
     # Context Parallel
     chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
@@ -94,20 +90,13 @@ def custom_forward(*inputs):
     x = chunks[get_sequence_parallel_rank()]
 
     for block in self.blocks:
-        if self.training and use_gradient_checkpointing:
-            if use_gradient_checkpointing_offload:
-                with torch.autograd.graph.save_on_cpu():
-                    x = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        x, context, t_mod, freqs,
-                        use_reentrant=False,
-                    )
-            else:
-                x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    x, context, t_mod, freqs,
-                    use_reentrant=False,
-                )
+        if self.training:
+            x = gradient_checkpoint_forward(
+                block,
+                use_gradient_checkpointing,
+                use_gradient_checkpointing_offload,
+                x, context, t_mod, freqs
+            )
         else:
             x = block(x, context, t_mod, freqs)
diff --git a/examples/flux/model_training/full/accelerate_config_zero3.yaml b/examples/flux/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 000000000..e6a8d2733
--- /dev/null
+++ b/examples/flux/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/flux2/model_training/full/accelerate_config_zero3.yaml b/examples/flux2/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 000000000..e6a8d2733
--- /dev/null
+++ b/examples/flux2/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh
new file mode 100644
index 000000000..ed678f251
--- /dev/null
+++ b/examples/flux2/model_training/special/npu_training/FLUX.2-dev-Lora-NPU.sh
@@ -0,0 +1,36 @@
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+export CPU_AFFINITY_CONF=1
+
+accelerate launch examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 1 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors,black-forest-labs/FLUX.2-dev:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --task "sft:data_process"
+
+accelerate launch --config_file examples/flux2/model_training/full/accelerate_config_zero3.yaml examples/flux2/model_training/train.py \
+  --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-dev:transformer/*.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-dev-LoRA-splited" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --initialize_model_on_cpu \
+  --task "sft:train"
diff --git a/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh
new file mode 100644
index 000000000..57755acf7
--- /dev/null
+++ b/examples/flux2/model_training/special/npu_training/FLUX.2-klein-9B-NPU.sh
@@ -0,0 +1,34 @@
+# This script is tested on 8*910B(NPU)
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+export CPU_AFFINITY_CONF=1
+
+accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+  --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.2-klein-9B_full" \
+  --trainable_models "dit" \
+  --use_gradient_checkpointing
+
+# Edit
+# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+#   --learning_rate 1e-5 \
+#   --num_epochs 2 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-9B_full" \
+#   --trainable_models "dit" \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py
index ea727b8e1..6101687db 100644
--- a/examples/flux2/model_training/train.py
+++ b/examples/flux2/model_training/train.py
@@ -85,6 +85,7 @@ def flux2_parser():
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
     parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
+    parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
     return parser
 
 
@@ -126,7 +127,7 @@ def flux2_parser():
         fp8_models=args.fp8_models,
         offload_models=args.offload_models,
         task=args.task,
-        device=accelerator.device,
+        device="cpu" if args.initialize_model_on_cpu else accelerator.device,
     )
     model_logger = ModelLogger(
         args.output_path,
diff --git a/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 000000000..e6a8d2733
--- /dev/null
+++ b/examples/qwen_image/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh
new file mode 100644
index 000000000..48922283d
--- /dev/null
+++ b/examples/qwen_image/model_training/special/npu_training/Qwen-Image-Edit-2509-NPU.sh
@@ -0,0 +1,19 @@
+# This script was tested with ZeRO-3 on 8*910B (NPU)
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+export CPU_AFFINITY_CONF=1
+
+accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero3.yaml examples/qwen_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+  --data_file_keys "image,edit_image" \
+  --extra_inputs "edit_image" \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image-Edit-2509:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image-Edit-2509_full" \
+  --trainable_models "dit" \
+  --use_gradient_checkpointing \
+  --find_unused_parameters
diff --git a/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 000000000..e6a8d2733
--- /dev/null
+++ b/examples/wanvideo/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/z_image/model_training/full/accelerate_config_zero3.yaml b/examples/z_image/model_training/full/accelerate_config_zero3.yaml
new file mode 100644
index 000000000..e6a8d2733
--- /dev/null
+++ b/examples/z_image/model_training/full/accelerate_config_zero3.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
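
Note for reviewers: every per-block checkpointing branch removed above is replaced by a call to diffsynth.core.gradient.gradient_checkpoint_forward. Only the first hunk of this patch touches that helper (adding determinism_check="none"), so its full body does not appear in the diff. The sketch below is a minimal reconstruction inferred from the hunk context and the call sites; the signature and argument order are assumptions, not the repository file. determinism_check="none" disables the non-reentrant checkpoint's recomputation-metadata check, presumably to avoid spurious mismatch errors when parameters are partitioned under DeepSpeed ZeRO-3.

# Reviewer sketch only -- NOT the repository file. Signature and branch order
# are inferred from the call sites and hunk context in this patch.
import torch


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        # Checkpoint the block and keep the saved activations in CPU memory.
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                model,
                *args, **kwargs,
                use_reentrant=False,
                determinism_check="none",  # added by this patch
            )
    elif use_gradient_checkpointing:
        # Plain non-reentrant activation checkpointing.
        model_output = torch.utils.checkpoint.checkpoint(
            model,
            *args, **kwargs,
            use_reentrant=False,
            determinism_check="none",  # added by this patch
        )
    else:
        # No checkpointing: run the block directly.
        model_output = model(*args, **kwargs)
    return model_output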