2 changes: 2 additions & 0 deletions diffsynth/core/gradient/gradient_checkpoint.py
@@ -21,13 +21,15 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
else:
model_output = model(*args, **kwargs)
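Note on the hunk above: `determinism_check="none"` is a keyword argument of `torch.utils.checkpoint.checkpoint` (PyTorch 2.1+) that, for non-reentrant checkpointing, disables the metadata comparison between tensors seen in the original forward and those produced during recomputation. A minimal sketch of the call pattern this enables; the toy module and shapes are illustrative only:

import torch

block = torch.nn.Linear(16, 16)             # toy module, illustrative only
x = torch.randn(2, 16, requires_grad=True)

# Non-reentrant checkpointing; determinism_check="none" skips the check that
# recomputation in the backward pass produced tensors with identical metadata.
y = torch.utils.checkpoint.checkpoint(
    block, x,
    use_reentrant=False,
    determinism_check="none",
)
y.sum().backward()
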
49 changes: 38 additions & 11 deletions diffsynth/core/loader/model.py
@@ -3,21 +3,24 @@
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch
from contextlib import contextmanager
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import ContextManagers


def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None,
use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
config = {} if config is None else config
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
with skip_model_initialization():
with ContextManagers(get_init_context(torch_dtype=torch_dtype, device=device)):
model = model_class(**config)
# What is `module_map`?
# This is a module mapping table for VRAM management.
if module_map is not None:
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"],
vram_config["computation_device"]]
device = [d for d in devices if d != "disk"][0]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"],
vram_config["computation_dtype"]]
dtype = [d for d in dtypes if d != "disk"][0]
if vram_config["offload_device"] != "disk":
state_dict = DiskMap(path, device, torch_dtype=dtype)
@@ -26,10 +29,12 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None,
vram_limit=vram_limit)
else:
disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map,
vram_limit=vram_limit)
else:
# Why do we use `DiskMap`?
# Sometimes a model file contains multiple models,
@@ -46,7 +51,11 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
state_dict = state_dict_converter(state_dict)
else:
state_dict = {i: state_dict[i] for i in state_dict}
model.load_state_dict(state_dict, assign=True)
if is_deepspeed_zero3_enabled():
from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model
_load_state_dict_into_zero3_model(model, state_dict)
else:
model.load_state_dict(state_dict, assign=True)
# Why do we call `to()`?
# Because some models override the behavior of `to()`,
# especially those from libraries like Transformers.
@@ -56,7 +65,8 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
return model


def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu",
state_dict_converter=None, module_map=None):
if isinstance(path, str):
path = [path]
config = {} if config is None else config
@@ -77,3 +87,20 @@ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=tor
}
enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
return model


def get_init_context(torch_dtype, device):
if is_deepspeed_zero3_enabled():
from transformers.modeling_utils import set_zero3_state
import deepspeed
# Why do we use `deepspeed.zero.Init`?
# It partitions the model weights on the CPU side as they are created
# and then loads each partition onto its compute device.
init_contexts = [deepspeed.zero.Init(remote_device=device, dtype=torch_dtype), set_zero3_state()]
else:
# Why do we use `skip_model_initialization`?
# It skips the random initialization of model parameters,
# thereby speeding up model loading and avoiding excessive memory usage.
init_contexts = [skip_model_initialization()]

return init_contexts
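The new `get_init_context` helper makes model construction ZeRO-3 aware: under DeepSpeed ZeRO-3 the model is built inside `deepspeed.zero.Init`, which partitions parameters across ranks as they are created, while the non-DeepSpeed path keeps the previous `skip_model_initialization` behavior. A minimal usage sketch of the new code path, assuming `get_init_context` is imported from `diffsynth.core.loader.model` and `MyModel` is a hypothetical module class:

import torch
from transformers.utils import ContextManagers

# ContextManagers enters every context manager in the returned list, so this is
# equivalent to nesting the ZeRO-3 (or skip-init) contexts by hand.
with ContextManagers(get_init_context(torch_dtype=torch.bfloat16, device="cpu")):
    model = MyModel()  # hypothetical model class
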
4 changes: 2 additions & 2 deletions diffsynth/diffusion/logger.py
@@ -18,8 +18,8 @@ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_ste

def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
@@ -34,8 +34,8 @@ def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save

def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
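The logger change moves `accelerator.get_state_dict(model)` out of the `is_main_process` branch. With sharded parameters (e.g. DeepSpeed ZeRO-3), gathering them into a full state dict is a collective operation that every rank must join; calling it only on the main process can hang the run. A minimal sketch of the general pattern; the function name and save path are illustrative:

import torch
from accelerate import Accelerator

def save_checkpoint(accelerator: Accelerator, model: torch.nn.Module, path: str):
    accelerator.wait_for_everyone()
    # Collective call: every rank participates so sharded parameters can be
    # gathered into a full state dict.
    state_dict = accelerator.get_state_dict(model)
    if accelerator.is_main_process:
        # Only the main process writes the file.
        torch.save(state_dict, path)
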
3 changes: 2 additions & 1 deletion diffsynth/diffusion/runner.py
@@ -27,7 +27,7 @@ def launch_training_task(
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

for epoch_id in range(num_epochs):
@@ -59,6 +59,7 @@ def launch_data_process_task(
num_workers = args.dataset_num_workers

dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, dataloader = accelerator.prepare(model, dataloader)

for data_id, data in enumerate(tqdm(dataloader)):
2 changes: 1 addition & 1 deletion diffsynth/models/wan_video_animate_adapter.py
@@ -607,7 +607,7 @@ def __init__(self, size, style_dim=512, motion_dim=20):

def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img)
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True, determinism_check="none")
motion = self.dec.direction(motion_feat)
return motion

27 changes: 8 additions & 19 deletions diffsynth/models/wan_video_dit.py
@@ -5,6 +5,7 @@
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward

try:
import flash_attn_interface
@@ -379,27 +380,15 @@ def forward(self,
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

for block in self.blocks:
if self.training and use_gradient_checkpointing:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, freqs,
use_reentrant=False,
)
if self.training:
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, freqs
)
else:
x = block(x, context, t_mod, freqs)

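The per-block checkpointing boilerplate is replaced by the shared `gradient_checkpoint_forward` helper from `diffsynth.core.gradient`. Based on its call sites and the `gradient_checkpoint.py` hunk above, the helper presumably behaves roughly like the sketch below (a reconstruction, not the exact source): it dispatches between plain execution, non-reentrant checkpointing, and checkpointing with activations offloaded to CPU.

import torch

def gradient_checkpoint_forward(model, use_gradient_checkpointing,
                                use_gradient_checkpointing_offload, *args, **kwargs):
    # Wrap the module so positional and keyword inputs pass through unchanged.
    def create_custom_forward(module):
        def custom_forward(*inputs, **inner_kwargs):
            return module(*inputs, **inner_kwargs)
        return custom_forward

    if use_gradient_checkpointing_offload:
        # Keep checkpointed activations in CPU memory instead of VRAM.
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model), *args, **kwargs,
                use_reentrant=False, determinism_check="none",
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model), *args, **kwargs,
            use_reentrant=False, determinism_check="none",
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
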
52 changes: 13 additions & 39 deletions diffsynth/models/wan_video_dit_s2v.py
@@ -4,6 +4,7 @@
import torch.nn.functional as F
from typing import Tuple
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
from ..core.gradient import gradient_checkpoint_forward


def torch_dfs(model: nn.Module, parent_name='root'):
@@ -545,46 +546,19 @@ def forward(
t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

for block_id, block in enumerate(self.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
context,
t_mod,
seq_len_x,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
x,
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x, context, t_mod, seq_len_x, pre_compute_freqs[0]
)
x = gradient_checkpoint_forward(
lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x),
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
x
)

x = x[:, :seq_len_x]
x = self.head(x, t[:-1])
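In the s2v refactor, the second `gradient_checkpoint_forward` call wraps `after_transformer_block` in a lambda, so only `x` flows through the checkpoint while `block_id`, `audio_emb_global`, and `merged_audio_emb` are captured by the closure. A toy illustration of the same pattern with plain `torch.utils.checkpoint`; names and shapes are illustrative:

import torch

extra = torch.randn(4, 8)                   # closed-over tensor, not a checkpoint argument
x = torch.randn(4, 8, requires_grad=True)

def post_block(x, extra):
    return x * 2 + extra

# Only `x` is passed through the checkpoint; `extra` rides along via the closure,
# which non-reentrant checkpointing supports.
y = torch.utils.checkpoint.checkpoint(
    lambda x: post_block(x, extra),
    x,
    use_reentrant=False,
    determinism_check="none",
)
y.sum().backward()
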
29 changes: 8 additions & 21 deletions diffsynth/models/wan_video_vace.py
@@ -1,6 +1,6 @@
import torch
from .wan_video_dit import DiTBlock

from ..core.gradient import gradient_checkpoint_forward

class VaceWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
@@ -62,26 +62,13 @@ def forward(
dim=1) for u in c
])

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

for block in self.vace_blocks:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
elif use_gradient_checkpointing:
c = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
c, x, context, t_mod, freqs,
use_reentrant=False,
)
else:
c = block(c, x, context, t_mod, freqs)
c = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
c, x, context, t_mod, freqs
)

hints = torch.unbind(c)[:-1]
return hints