diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index da443c4f6..16f7d92c8 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -15,7 +15,6 @@ import torch import torch.monitor -from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, ReduceOp, @@ -29,6 +28,15 @@ logger = logging.getLogger(__name__) +def _get_device(group: ProcessGroup) -> torch.device: + if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL): + return torch.device(torch.cuda.current_device()) + elif isinstance(group, torch.distributed.ProcessGroupGloo): + return torch.device("cpu") + else: + raise NotImplementedError(type(group)) + + @contextlib.contextmanager def set_timeout(group: ProcessGroup | None, timeout: float | None = None): if group is not None and timeout is not None: @@ -42,7 +50,7 @@ def set_timeout(group: ProcessGroup | None, timeout: float | None = None): def broadcast( tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None -) -> Work | None: +) -> torch.distributed.Work | None: """Same as torch.distributed.broadcast, but without the complication of going through the global rank.""" assert group is not None opts = torch.distributed.BroadcastOptions() @@ -72,12 +80,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier( - group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None -) -> None: +def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -88,10 +94,9 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, - device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() @@ -106,7 +111,7 @@ def all_gather_scalar( timeout: float | None = None, ): if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) output_tensor = value.new_empty((group.size(),)) with set_timeout(group, timeout): torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) @@ -116,7 +121,7 @@ def all_gather_scalar( def broadcast_scalar( - value: float | int, + value: float | int | None, dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, src: int = 0, @@ -124,7 +129,7 @@ def broadcast_scalar( ) -> float | int: if not group: return value - tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device())) + tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group))) if group.rank() == src: tensor.fill_(value) broadcast(tensor, src, group, timeout=timeout) @@ -141,19 +146,21 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None if group.rank() == src: tensor = _object_to_tensor(input_object) size = tensor.numel() - broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast_tensor.copy_(tensor) broadcast_scalar(size, torch.int64, group, src) broadcast(broadcast_tensor, src, group) return input_object else: size = int(broadcast_scalar(None, torch.int64, group, src)) - output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast(output_tensor, src, group) return _tensor_to_object(output_tensor) -def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: +def send( + tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0 +) -> torch.distributed.Work | None: assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # send not supported for gloo on GPU. @@ -169,7 +176,9 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta return None -def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: +def recv( + tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0 +) -> torch.distributed.Work | None: assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # recv not supported for gloo on GPU. diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index bb61aadd0..7d361a22e 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -8,7 +8,6 @@ import torch import torch._dynamo # noqa import torch.autograd -from torch._C._distributed_c10d import Work from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from fast_llm.utils import Assert, div @@ -18,7 +17,7 @@ def reduce_op( input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: if group: handle = all_reduce(input_, group=group, async_op=async_op, op=op) else: @@ -62,7 +61,7 @@ def swap_mult_dim(tensor: torch.Tensor, factor: int, old_dim: int, new_dim: int) def gather_op( input_: torch.Tensor, group: ProcessGroup | None, dim: int, async_op: bool = False, out=None -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: """Gather tensors and concatenate along the last dimension.""" # Bypass the function if we are using only 1 GPU. if not group: @@ -89,7 +88,7 @@ def reduce_scatter_op( op: ReduceOp = ReduceOp.SUM, dim: int = 0, async_op: bool = False, -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: """Reduce-scatter the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. if not group: diff --git a/fast_llm/data/data_loader.py b/fast_llm/data/data/data_loader.py similarity index 100% rename from fast_llm/data/data_loader.py rename to fast_llm/data/data/data_loader.py diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index e572e8e61..17f151919 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,8 +8,8 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader import SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.data_loader import SampledDatasetIterator from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 157744f51..4408ca772 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: import numpy as np import torch + import transformers @config_class(dynamic_type={PreprocessingConfig: "tokenizer"}) @@ -52,7 +53,7 @@ def __init__(self, config: ConfigType): from transformers import AutoTokenizer log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: "transformers.PreTrainedTokenizer" = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=self._config.path, errors="replace", max_len=None, @@ -70,10 +71,15 @@ def __init__(self, config: ConfigType): @functools.cached_property def vocab_size(self) -> int: - out = len(self.tokenizer) - if self._config.max_vocab_size is not None: - out = min(out, self._config.max_vocab_size) - return out + return ( + self._tokenizer_vocab_size + if self._config.max_vocab_size is None + else min(self._tokenizer_vocab_size, self._config.max_vocab_size) + ) + + @functools.cached_property + def _tokenizer_vocab_size(self) -> int: + return len(self.tokenizer) @property def vocab(self) -> dict[str, int]: @@ -99,7 +105,11 @@ def tokenize( tokens = ( torch.tensor( tokens, - dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + dtype=( + torch.int64 + if self._tokenizer_vocab_size > torch.iinfo(data_type.torch).max + else data_type.torch + ), ) % self._config.max_vocab_size ).to(data_type.torch) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 7ae537104..32ea60cb8 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -85,7 +85,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return PatchSample( + return self.__class__( self.patches.new_empty((0, *self.patches.shape[1:])), self.token_map.new_empty(0), self.positions.new_empty([0, self.patches.ndim - 2]), diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 53683342a..f57ee04d9 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -52,7 +52,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return RangeSample([], size) + return self.__class__([], size) class RangeBatch(Batch): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index cd4d7fa02..6ab55dbba 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -58,7 +58,7 @@ def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size]) class TokenBatch(Batch): diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 3f1970538..98303539e 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -141,11 +141,12 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): + _abstract = False model_weights: bool = FieldUpdate(desc="Save the model weights.") optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.") def _validate(self) -> None: - if self.optimizer_state is None: + if self.optimizer_state is None and hasattr(self.format, "support_optimizer"): with self._set_implicit_default(): # TODO: Make sure it's a type self.optimizer_state = self.format.support_optimizer diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index ae37410ae..f84f36309 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -4,7 +4,6 @@ import typing import torch -from torch._C._distributed_c10d import ReduceOp from torch.distributed import all_reduce, reduce_scatter_tensor from fast_llm.core.distributed import ProcessGroup @@ -398,7 +397,7 @@ def reduce_gradients( out, self._grad_buffer, group=self._fsdp_group, - op=ReduceOp.AVG, + op=torch.distributed.ReduceOp.AVG, ) if accumulate: triton_add(self._grad_shard, out, self._grad_shard) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 698f62daa..ed293b103 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -4,7 +4,6 @@ import warnings import torch -from torch._C._distributed_c10d import ProcessGroup from fast_llm.config import Configurable from fast_llm.engine.base_model.base_model import BaseModel @@ -611,7 +610,7 @@ class TiedParameter: # Whether the local rank is involved at all. on_device: bool # Process group for reduction. - group: ProcessGroup | None = dataclasses.field(repr=False, init=False) + group: torch.distributed.ProcessGroup | None = dataclasses.field(repr=False, init=False) all_ranks: set[int] # The index of the main stage. main_stage: int diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index f1212f4b8..d56c745ae 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -1,31 +1,42 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.utils import Assert -def _torch_entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, +@torch.compile +def torch_entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, torch.Tensor | None]: # (), (*batch, vocab) """ A wrapper for the pytorch implementation of cross-entropy. The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. - TODO: loss masking only works for with labels format and if the masking index is set to -100. """ + + # Torch methods require flattened batch dimension. + target = target.flatten() if target_format == TargetFormat.labels else target.flatten(0, -2) + if target_format == TargetFormat.labels: + assert loss_mask is None + loss_mask = target >= 0 + else: + target = target.float() + if loss_mask is not None: + loss_mask = loss_mask.flatten() + # Torch compile doesn't understand this. with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) - logits_scaled = logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor + + logits_scaled = (logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor).flatten(0, -2) if target_format == TargetFormat.logits: target_scale = logits_scale_factor / temperature target = target if target_scale == 1.0 else target * target_scale @@ -35,9 +46,7 @@ def _torch_entropy_loss_forward_backward( if entropy_loss_type == EntropyLossType.cross_entropy: if target_format == TargetFormat.logits: target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy( - logits_scaled, target, reduction="mean" if loss_mask is None else "none" - ) + per_sample_loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction="none") else: predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) if target_format == TargetFormat.logits: @@ -45,30 +54,33 @@ def _torch_entropy_loss_forward_backward( elif target_format == TargetFormat.probabilities: target_log_probability = target.log() else: - target_log_probability = ( - torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).add(1.0e-10).log() + target_probability = torch.nn.functional.one_hot( + torch.clamp_min(target, 0), num_classes=logits_scaled.size(-1) ) + if loss_mask is not None: + target_probability = target_probability * loss_mask.unsqueeze(-1) + target_log_probability = target_probability.add(1.0e-10).log() if entropy_loss_type == EntropyLossType.forward_kl: - loss = torch.nn.functional.kl_div( + per_sample_loss = torch.nn.functional.kl_div( predicted_log_probability, target_log_probability, - reduction="batchmean" if loss_mask is None else "none", + reduction="none", log_target=True, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - loss = torch.nn.functional.kl_div( + per_sample_loss = torch.nn.functional.kl_div( target_log_probability, predicted_log_probability, - reduction="batchmean" if loss_mask is None else "none", + reduction="none", log_target=True, ) else: raise NotImplementedError(entropy_loss_type) - if loss_mask is not None: - loss = loss.sum(dim=-1) + per_sample_loss = per_sample_loss.sum(dim=-1) if loss_mask is not None: - loss = (loss * loss_mask).mean() + per_sample_loss = per_sample_loss * loss_mask + loss = per_sample_loss.mean() if grad_output is None: grad = None @@ -79,49 +91,61 @@ def _torch_entropy_loss_forward_backward( @torch.compile -def _fused_softmax_base( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def fused_softmax_base( + logits: torch.Tensor, # (*batch, vocab) + logits_scale_factor: float = 1.0, + group: ProcessGroup | None = None, + dim: int = -1, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: # (*batch, vocab), (*batch, vocab), (*batch,), (*batch,) + """ + Calculate the required inputs for softmax computation, mainly sum_exp_logits, + in a numerically stable way and with tensor-parallel support. + Warning: The returned values are regularized by `logits_max`. + The regularization typically but not always cancels out in derived quantities. + """ logits = logits.float() if logits_scale_factor != 1.0: logits = logits * logits_scale_factor - logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + logits_max = logits.max(dim=dim)[0] if group is not None: all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = (logits - logits_max).float() + logits_norm = (logits - logits_max.unsqueeze(-1)).float() exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + sum_exp_logits = exp_logits.sum(dim=dim) if group is not None: all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - return logits_norm, exp_logits, sum_exp_logits + return logits_norm, exp_logits, sum_exp_logits, logits_max @torch.compile def _fused_reverse_kl_base( - logits: torch.Tensor, - target: torch.Tensor, + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch, vocab) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, temperature: float = 1.0, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - predicted_log_probability = logits_norm - sum_exp_logits.log() - predicted_probability = exp_logits / sum_exp_logits +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + assert target_format in (TargetFormat.logits, TargetFormat.probabilities) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_log_probability = logits_norm - sum_exp_logits.log().unsqueeze(-1) + predicted_probability = exp_logits / sum_exp_logits.unsqueeze(-1) if target_format == TargetFormat.logits: - target_logits_norm, _, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, _, sum_exp_target_logits, _ = fused_softmax_base( target, logits_scale_factor / temperature, group ) - target_log_probability = target_logits_norm - sum_exp_target_logits.log() + target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1) else: target_log_probability = torch.log(target) # Compute loss terms: student_probs * log_ratio, then sum over vocab # This is equivalent to kl_div(..., log_target=True) but more memory efficient log_ratio = predicted_log_probability - target_log_probability - per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True) + per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1) if group is not None: all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) @@ -130,39 +154,39 @@ def _fused_reverse_kl_base( else: # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) # where E_q[log(q/p)] is the expected log ratio under the student distribution - grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output return per_sample_loss, grad @torch.compile def _fused_cross_entropy_base( - logits: torch.Tensor, - target: torch.Tensor, + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch, vocab) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, temperature: float = 1.0, return_kl_loss: bool = False, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target_logits_norm, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, exp_logits_targets, sum_exp_target_logits, _ = fused_softmax_base( target, logits_scale_factor / temperature, group ) - target = exp_logits_targets / sum_exp_target_logits + target = exp_logits_targets / sum_exp_target_logits.unsqueeze(-1) # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) if return_kl_loss: if target_format == TargetFormat.logits: - target_log_probability = target_logits_norm - sum_exp_target_logits.log() + target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1) else: target_log_probability = torch.log(target) logits_norm = logits_norm - target_log_probability - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + predicted_logits = (target * logits_norm).sum(dim=-1) if group is not None: # We need to sum the over the tensor-parallel group, # but this is handled in the final averaging provided we multiply by the group size. @@ -174,41 +198,63 @@ def _fused_cross_entropy_base( grad = None else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - grad = (exp_logits - sum_exp_logits * target) * (grad_output / sum_exp_logits) + grad = (exp_logits - sum_exp_logits.unsqueeze(-1) * target) * (grad_output / sum_exp_logits.unsqueeze(-1)) return per_sample_loss, grad @torch.compile -def _fused_cross_entropy_base_from_labels( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor, - grad_output: float | None, - logits_scale_factor: float, +def fused_predicted_logits_from_labels( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + loss_mask: torch.Tensor, # (*batch,), == target>=0 group: ProcessGroup | None = None, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # (*batch,), (*batch,), (*batch,) + """ + Recover the value of the logits at the target index, with support for masking (target < 0) and tensor parallelism. + In the simple case, equivalent to `logits.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)` - target = target.unsqueeze(-1) + Normally used in combination with `fused_softmax_base`, may also recover probabilities or log probabilities: + `predicted_probabilities = predicted_logits.exp() / sum_exp_logits` + `predicted_log_probabilities = predicted_logits / sum_exp_logits.log()` + """ if group is None: # Keep values within range for scatter and gather ops to work. - target = target * loss_mask.unsqueeze(-1) + target_masked = target * loss_mask target_mask = None else: - # Mask the target (fused) + # Mask the target (fused). # TODO: Could mask earlier on cpu or overlap with reduce? vocab_start_index = logits.size(-1) * group.rank() target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + target_masked = (target - vocab_start_index) * target_mask # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss is the same because P * log(P) == 0. - predicted_logits = logits_norm.gather(1, target) + predicted_logits = logits.gather(-1, target_masked.unsqueeze(-1)).squeeze(-1) if group is not None: predicted_logits = target_mask * predicted_logits all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + return predicted_logits, target_masked, target_mask + + +@torch.compile +def _fused_cross_entropy_base_from_labels( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + loss_mask: torch.Tensor, # (*batch,) + grad_output: float | None, + logits_scale_factor: float, + group: ProcessGroup | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss is the same because P * log(P) == 0. per_sample_loss = sum_exp_logits.log() - predicted_logits if grad_output is None: @@ -216,17 +262,19 @@ def _fused_cross_entropy_base_from_labels( else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. grad = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) - ) * (grad_output / sum_exp_logits) + -1, + target_masked.unsqueeze(-1), + -sum_exp_logits.unsqueeze(-1) if target_mask is None else -(target_mask * sum_exp_logits).unsqueeze(-1), + ) * (grad_output / sum_exp_logits.unsqueeze(-1)) return per_sample_loss, grad @torch.compile -def _fused_entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, +def fused_entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -239,11 +287,11 @@ def _fused_entropy_loss_forward_backward( It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, but still suboptimal because it needs multiple kernels. """ - grad_output = None if grad_output is None else grad_output / logits.size(0) * logits_scale_factor + grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor if target_format == TargetFormat.labels: assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) - if loss_mask is None: - loss_mask = target >= 0 + assert loss_mask is None + loss_mask = target >= 0 per_sample_loss, grad = _fused_cross_entropy_base_from_labels( logits, target, @@ -277,7 +325,7 @@ def _fused_entropy_loss_forward_backward( raise NotImplementedError(entropy_loss_type) if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1) + per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() if grad is not None: @@ -286,63 +334,3 @@ def _fused_entropy_loss_forward_backward( grad = grad.to(logits.dtype) return loss, grad - - -_ENTROPY_LOSS_IMPLEMENTATIONS = { - EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, - EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, - EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, -} - - -def entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - implementation: EntropyLossImplementation = EntropyLossImplementation.fused, - logits_scale_factor: float = 1.0, - temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, - entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Select the appropriate implementation of cross-entropy. - The triton implementation from the triton submodule is the fastest and recommended one. - It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, - which is faster and has a relatively small memory overhead. - """ - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - assert loss_mask is None - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, EntropyLossImplementation.fused) - return _fused_entropy_loss_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - group, - temperature, - ) - else: - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - temperature=temperature, - ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 709d0c52d..ef2039ade 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -140,7 +140,8 @@ def triton_cross_entropy_forward_backward( # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() - n_rows, n_cols = logits.shape + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) @@ -155,8 +156,8 @@ def triton_cross_entropy_forward_backward( losses, None if grad_output is None else grad_output / n_rows, n_cols, - logits.stride(0), - None if grad_output is None else grad_logits.stride(0), + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), logits_scale_factor, block_size=block_size, num_warps=num_warps, @@ -172,9 +173,9 @@ def triton_cross_entropy_forward_backward( losses, None if grad_output is None else grad_output / n_rows, n_cols, - logits.stride(0), - target.stride(0), - None if grad_output is None else grad_logits.stride(0), + logits.stride(-2), + target.stride(-2), + None if grad_output is None else grad_logits.stride(-2), logits_scale_factor, block_size=block_size, num_warps=num_warps, diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index 15c4c788c..2194c6f86 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -49,17 +49,18 @@ def dpo_loss( beta: float = 1.0, logits_scale_factor: float = 1.0, ) -> torch.Tensor: + logits = logits.float() if logits_scale_factor != 1.0: # TODO: Make more efficient. logits = logits * logits_scale_factor - policy_log_probabilities = _get_target_log_probabilities(logits, targets) + policy_log_probabilities = get_target_log_probabilities(logits, targets) policy_log_ratios = _get_target_log_probability_for_spans( policy_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets) + reference_log_probabilities = get_target_log_probabilities(reference_model_logits.float().detach(), targets) reference_log_ratios = _get_target_log_probability_for_spans( reference_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) @@ -68,14 +69,17 @@ def dpo_loss( return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean() -def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): - # Gather log probabilities corresponding to the target tokens - return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): return sum( log_probabilities[sample_index, begin:end].sum() for sample_index, sample_spans in enumerate(spans) for begin, end in sample_spans ) + + +@torch.compile +def get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + # Avoid negative (masked) labels. + targets = targets * (targets >= 0) + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 3ae87d2e9..1dfd3920c 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -3,15 +3,17 @@ import torch from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward +from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.utils import Assert -def _get_imlementation( +def _get_implementation( default: EntropyLossImplementation = EntropyLossImplementation.auto, loss_type: EntropyLossType = EntropyLossType.cross_entropy, vocab_parallel: bool = False, @@ -34,7 +36,7 @@ def _get_imlementation( class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._implementation = _get_imlementation( + self._implementation = _get_implementation( self._config.implementation, self._config.loss_type, self._vocab_parallel ) @@ -63,7 +65,7 @@ def __init__(self, *args, **kwargs): if self._prediction_distance > 0: raise NotImplementedError() - self._implementation = _get_imlementation( + self._implementation = _get_implementation( self._config.implementation, self._config.loss_type, self._vocab_parallel ) @@ -84,3 +86,63 @@ def forward_backward( target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, ) + + +_ENTROPY_LOSS_IMPLEMENTATIONS = { + EntropyLossImplementation.torch: torch_entropy_loss_forward_backward, + EntropyLossImplementation.fused: fused_entropy_loss_forward_backward, + EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, +} + + +def entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + implementation: EntropyLossImplementation = EntropyLossImplementation.fused, + logits_scale_factor: float = 1.0, + temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Select the appropriate implementation of cross-entropy. + The triton implementation from the triton submodule is the fastest and recommended one. + It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, + which is faster and has a relatively small memory overhead. + """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + assert loss_mask is None + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + if group: + Assert.eq(implementation, EntropyLossImplementation.fused) + return fused_entropy_loss_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature, + ) + else: + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + temperature=temperature, + ) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 711560a8f..41e8942ac 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -116,6 +116,6 @@ def loss_forward_backward( grad = None else: loss.backward(torch.full_like(loss, grad_output)) - grad = input_.grad.detach().to(input_.dtype) + grad = input_.grad.detach() return loss, grad diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index c94851bf2..82b8d5318 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -2,8 +2,9 @@ import torch +from fast_llm.functional.entropy_loss import fused_softmax_base from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig -from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]): @@ -19,12 +20,12 @@ def forward_backward( kwargs: dict[str, typing.Any], split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return loss_forward_backward( - self._get_grad_output(kwargs), - z_loss, + return z_loss_forward_backward( logits, self._get_loss_mask(kwargs, split_index), - self._logits_scale_factor, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + logits_scale_factor=self._logits_scale_factor, ) @@ -34,10 +35,41 @@ def z_loss( loss_mask: "torch.Tensor | None" = None, logits_scale_factor: float = 1.0, ) -> torch.Tensor: - """ - Z-loss = mean(logsumexp(logits, dim=-1) ** 2) - """ + # TODO: Replace usage in MoE, move to testing. + logits = logits.float() out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 if loss_mask is not None: out = out * loss_mask return torch.mean(out) + + +@torch.compile +def z_loss_forward_backward( + logits: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + Grad = 2 * log_sum_exp_logits * softmax(logits) + """ + grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group) + log_sum_exp_logits = sum_exp_logits.log() + logits_max + + per_sample_loss = log_sum_exp_logits**2 + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask + loss = per_sample_loss.mean() + + if grad_output is None: + grad = None + else: + grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits) + if loss_mask is not None: + grad_base = grad_base * loss_mask + grad = (grad_base.unsqueeze(-1) * exp_logits).to(logits.dtype) + + return loss, grad diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a315beecc..314741c3b 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -54,6 +54,11 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_preference_spans: bool = Field( + default=False, + desc="Read dpo data (chosen and rejected spans) from the dataset.", + hint=FieldHint.feature, + ) truncate_documents: bool | None = Field( default=True, desc=( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 768d3fdd7..ded0f81c8 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,11 +33,11 @@ def _get_sampling_parameters( def _get_preprocessing_config( self, *, _return_dict: bool = False ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = { "type": "language_model", "vocab_size": self._config.model.base_model.embeddings.vocab_size, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - # OK since DPO is not supported for MTP. - "use_preference_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + "use_preference_spans": self._config.batch.use_preference_spans, } return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/tests/conftest.py b/tests/conftest.py index f93eec215..4f7d7bad0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,7 +256,7 @@ def pytest_runtest_call(item: pytest.Function): if torch.cuda.is_available(): # Empty cache to check is cuda is still working (TODO: Is there a better way? Can we kill the worker?) try: - torch.cuda.empty_cache() + torch.cuda.synchronize() except RuntimeError: pytest.skip("Cuda runtime unavailable due to an error in an earlier test.") manager.handle_missing(item) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py deleted file mode 100644 index 35d1ef648..000000000 --- a/tests/functional/test_entropy_loss.py +++ /dev/null @@ -1,178 +0,0 @@ -import pathlib - -import pytest -import torch - -from fast_llm.engine.distributed.config import DistributedBackend -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.entropy_loss import entropy_loss_forward_backward -from fast_llm.utils import Assert -from tests.utils.subtest import DistributedTestContext - - -def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = "cuda" if torch.cuda.is_available() else "cpu" - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.float32, device=device) / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) - logits = torch.nn.functional.one_hot(target, num_columns) + logits_var - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.float32, device=device) - logits = target + logits_var - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - return logits, target, loss_mask - - -def _compare_entropy_loss_outputs( - loss: torch.Tensor, - ref_loss: torch.Tensor, - has_grad: bool, - grad: torch.Tensor | None, - ref_grad: torch.Tensor | None, - threshold=1e-5, - loss_min_threshold=1e-6, -): - Assert.rms_close_relative(loss, ref_loss, threshold, loss_min_threshold) - if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) - else: - assert grad is None - assert ref_grad is None - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), - ( - (8192, 1.0, 1.0, False), # Simple - (5000, 1.0, 1.0, False), # Not a power of 2 - (5000, None, 1.0, False), # No grad - (5000, 1.0, 4.0, False), # Loss scaling - (5000, 4.0, 1.0, False), # Grad scaling - (5000, 1.0, 1.0, True), # Loss masking - (65536, 1.0, 1.0, False), # Max block size - (65537, 1.0, 1.0, False), # Above max block size - ), -) -@pytest.mark.parametrize("target_format", TargetFormat) -@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) -def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") - # TODO: Test tensor-parallel implementation. - logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) - kwargs = { - "logits": logits, - "target": target, - "loss_mask": loss_mask, - "grad_output": grad_output, - "logits_scale_factor": logits_scale_factor, - "target_format": target_format, - "entropy_loss_type": entropy_loss_type, - } - # Torch serves as the reference implementation. - out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) - out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) - - _compare_entropy_loss_outputs( - out_fused, - out_torch, - grad_output is not None, - grad_fused, - grad_torch, - loss_min_threshold=5e-6, - ) - - if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): - # Triton implementation only supports cross-entropy. - return - assert TritonConfig.TRITON_ENABLED - if num_columns > 65536: - with pytest.raises(AssertionError): - entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton) - else: - out_triton, grad_triton = entropy_loss_forward_backward( - **kwargs, implementation=EntropyLossImplementation.triton - ) - _compare_entropy_loss_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) - - -def _entropy_loss_distributed( - target_format: TargetFormat, - entropy_loss_type: EntropyLossType, - loss_masking: bool, - group: torch.distributed.ProcessGroup, -): - # Ensure all workers have the same inputs. - torch.manual_seed(0) - rank = group.rank() - world_size = group.size() - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - - kwargs = { - "loss_mask": loss_mask, - "grad_output": 1.0, - "target_format": target_format, - "implementation": EntropyLossImplementation.fused, - "entropy_loss_type": entropy_loss_type, - } - out_ref, grad_ref = entropy_loss_forward_backward(logits, target, **kwargs) - - out, grad = entropy_loss_forward_backward( - logits.chunk(world_size, 1)[rank], - target if target_format == TargetFormat.labels else target.chunk(world_size, 1)[rank], - group=group, - **kwargs, - ) - _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) - - -def _run_entropy_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path): - for entropy_loss_type in EntropyLossType: - for target_format in TargetFormat: - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - continue - for loss_masking in [False, True]: - name = f"{entropy_loss_type}_{target_format}_{loss_masking}" - with test_context.subtest(base_path, name, 2) as subtest: - if subtest.do_run: - _entropy_loss_distributed(target_format, entropy_loss_type, loss_masking, test_context.group) - - -@pytest.mark.slow -def test_entropy_loss_distributed_dependency(): - # Mock test so the distributed subtest are placed in the same dependency group. - pass - - -@pytest.mark.slow -@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) -def test_run_entropy_loss_distributed(run_parallel_script, result_path): - run_parallel_script( - _run_entropy_loss_distributed, - (result_path / "test_entropy_loss",), - world_size=2, - backend=DistributedBackend.gloo, - use_cuda=False, # Disable device count check. - ) - - -# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in cas of failure. -# This should still run after `test_run_entropy_loss_distributed` -@pytest.mark.slow -@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) -@pytest.mark.parametrize("target_format", TargetFormat) -@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) -@pytest.mark.parametrize("loss_masking", (False, True)) -def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") - report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 840e3846d..6471a516f 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,13 +1,10 @@ -import numpy as np import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.utils import Assert -from tests.utils.dataset import get_random_spans def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): @@ -18,59 +15,6 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans ) -def reference_dpo_loss( - logits: torch.Tensor, - targets: torch.Tensor, - reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, - beta: float, -) -> torch.Tensor: - # TODO: Too similar to the actual implementation. - policy_log_probs = ( - torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - ) - policy_chosen_logps = sum( - policy_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(chosen_spans) - for begin, end in sample_spans - ) - policy_rejected_logps = sum( - policy_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(rejected_spans) - for begin, end in sample_spans - ) - reference_log_probs = ( - torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) - .gather(dim=-1, index=targets.unsqueeze(-1)) - .squeeze(-1) - ) - reference_chosen_logps = sum( - reference_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(chosen_spans) - for begin, end in sample_spans - ) - reference_rejected_logps = sum( - reference_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(rejected_spans) - for begin, end in sample_spans - ) - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() - - -def test_dpo_loss(): - logits = torch.normal(0, 1, (10, 50, 100)) - reference_model_logits = torch.normal(0, 1, (10, 50, 100)) - targets = torch.randint(0, 100, (10, 50)) - spans = get_random_spans(np.full(10, 50), 0, 10) - - fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2]) - reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) - Assert.rms_close(fastllm_loss, reference_loss, 1e-5) - - @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py new file mode 100644 index 000000000..639a3ba7c --- /dev/null +++ b/tests/layers/test_lm_losses.py @@ -0,0 +1,346 @@ +import contextlib +import pathlib +import random + +import numpy as np +import pytest +import torch + +from fast_llm.core.ops import split_op +from fast_llm.engine.config_utils import data_type +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedBackend +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.layers.language_model.loss.dpo import dpo_loss +from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward +from fast_llm.layers.language_model.loss.loss import loss_forward_backward +from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward +from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans +from tests.utils.subtest import DistributedTestContext + +VOCAB_SIZE = 100 +NUM_TOKENS = 200 + + +def _get_lm_loss_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat, batch_shape: tuple[int], dtype +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn((*batch_shape, num_columns), dtype=dtype.torch, device=device) / 3 + loss_mask = torch.randint(0, 2, batch_shape, dtype=torch.bool, device=device) if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, batch_shape, dtype=torch.int64, device=device) + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + target = torch.where(loss_mask, target, -100) + loss_mask = None + else: + # Target logits are typically in training precision, ex. with distillation model. + target = torch.randn((*batch_shape, num_columns), dtype=dtype.torch, device=device) + logits = target + logits_var + if target_format == TargetFormat.probabilities: + # Probabilities need to be in full precision for accuracy. + target = torch.softmax(target, -1, dtype=torch.float32) + return logits, target, loss_mask + + +def _compare_losses_and_grads( + loss: torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, + threshold=1e-5, + group: torch.distributed.ProcessGroup | None = None, +): + Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) + if has_grad: + Assert.rms_close_relative( + grad, split_op(ref_grad, group, -1), threshold, 1e-8 if grad.dtype == torch.float32 else 1e-7 + ) + else: + assert grad is None + assert ref_grad is None + + +def reference_dpo_loss( + logits: torch.Tensor, + labels: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, +) -> torch.Tensor: + # TODO: Too similar to the actual implementation. + policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=labels.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() + + +_BATCH_SHAPES = ((64,), (16, 8)) +_LOSS_PARAMETERS = ( + (500, 1.0, 1.0, False, DataType.float32), # Simple + (512, 1.0, 1.0, False, DataType.float32), # Power of 2 + (500, None, 1.0, False, DataType.float32), # No grad + (500, 1.0, 4.0, False, DataType.float32), # Loss scaling + (500, 4.0, 1.0, False, DataType.float32), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32), # Loss masking + (500, 1.0, 1.0, False, DataType.float16), # Fp16 + (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16, loss masking + (65538, 1.0, 1.0, False, DataType.float32), # Above max block size +) + + +def _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + group=None, +): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + # TODO: Test tensor-parallel implementation. + logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) + # Torch serves as the reference implementation. + out_ref, grad_ref = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.torch, + ) + out_fused, grad_fused = entropy_loss_forward_backward( + logits=split_op(logits, group, -1), + target=target if target_format == TargetFormat.labels else split_op(target, group, -1), + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.fused, + ) + + _compare_losses_and_grads( + out_fused, + out_ref, + grad_output is not None, + grad_fused, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) + + if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available() or group is not None: + # Triton implementation only supports cross-entropy. + return + assert TritonConfig.TRITON_ENABLED + with pytest.raises(AssertionError) if num_columns > 65536 else contextlib.nullcontext(): + out_triton, grad_triton = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.triton, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + + +def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): + logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) + out_ref, grad_ref = loss_forward_backward( + grad_output, + z_loss, + logits, + loss_mask, + logits_scale_factor=logits_scale_factor, + ) + out_fused, grad_fused = z_loss_forward_backward( + split_op(logits, group, -1), + loss_mask, + grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + ) + _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + + +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS +) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +def test_entropy_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type, dtype +): + _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS +) +def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype): + _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) + + +@pytest.mark.skip(reason="DPO loss is broken") +def test_dpo_loss(): + logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) + reference_model_logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) + labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,)) + spans = get_random_spans(np.full(10, 50), 0, 10) + + fast_llm_loss = dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2]) + reference_loss = reference_dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2], beta=1) + Assert.rms_close(fast_llm_loss, reference_loss, 1e-5) + + +def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path, seed: int): + for batch_shape in _BATCH_SHAPES: + for num_columns, grad_output, logits_scale_factor, loss_masking, dtype in _LOSS_PARAMETERS: + suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}" + # Entropy loss + for entropy_loss_type in EntropyLossType: + for target_format in TargetFormat: + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + continue + with test_context.subtest( + base_path, f"{entropy_loss_type}-{target_format}-{suffix}", 2 + ) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + test_context.group, + ) + # Z loss + with test_context.subtest(base_path, f"z_loss-{suffix}", 2) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_z_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, + test_context.group, + ) + + +@pytest.mark.slow +def test_lm_loss_distributed_dependency(): + # Mock test so the distributed subtest are placed in the same dependency group. + pass + + +# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in cas of failure. +# This should still run after `test_run_entropy_loss_distributed` +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) +def test_run_lm_loss_distributed(run_parallel_script, result_path): + run_parallel_script( + _run_lm_loss_distributed, + (result_path / "test_losses", random.randint(0, 2**32 - 1)), + world_size=2, + backend=DistributedBackend.gloo, + use_cuda=False, # Disable device count check. + ) + + +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS +) +@pytest.mark.parametrize( + "loss_type", + ( + *( + f"{entropy_loss_type}-{target_format}" + for entropy_loss_type in EntropyLossType + for target_format in TargetFormat + if target_format != TargetFormat.labels or entropy_loss_type != EntropyLossType.reverse_kl + ), + "z_loss", + ), +) +def test_lm_loss_distributed( + result_path, + report_subtest, + loss_type, + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, +): + report_subtest( + result_path + / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}", + 2, + use_cuda=False, + ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 854ecec36..d1b627ecc 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -10,6 +10,7 @@ from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import padded_cumsum from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH @@ -184,6 +185,8 @@ def _get_test_dataset( hf_path = path / "hf" if not config_only and not all(config_path.is_file() for config_path in config_paths): + # Not supported for parallel tests, but dataset should already exist anyway. + assert DistributedConfig.default_world_size == 1 dataset = _get_hf_test_dataset( seed=seed, num_documents=num_documents, @@ -322,9 +325,10 @@ def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, + num_documents=200, max_loss_masking_spans=5, max_vocab_size=MODEL_TEST_VOCAB_SIZE, - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 18, "test": 2}, config_only=config_only, ) @@ -333,6 +337,7 @@ def get_multimodal_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset_multimodal", seed=1234, + num_documents=200, max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_images=2, image_patch_config=ImagePatchConfig( @@ -343,6 +348,6 @@ def get_multimodal_test_dataset(config_only: bool = False): image_break_token=None, image_end_token=None, ), - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 18, "test": 2}, config_only=config_only, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 5e7526377..5a6aff831 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -986,7 +986,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # similar to gdn with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + compare_factor=15.0, # similar to gdn with compare_factor 2 fails fp16 and bf16 tests in the normalization layer when using rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index b6764c0e2..b8f0b5b7a 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -51,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start", device=self._pool.device) + safe_barrier(self._group, "start") return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end", device=self._pool.device) + safe_barrier(self._group, "testing end") # Let pytest know how things went. # These should already be reported above, we repeat for convenience. if self._failures: @@ -138,13 +138,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(group, self._name, device=self._test_context._pool.device) - self._success = ( - allreduce_scalar( - self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device - ) - == group.size() - ) + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. @@ -165,6 +160,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): def do_run(self) -> bool: return self._do_run and not self._skip + @property + def name(self) -> str: + return self._name + def set_subtest_success(path: pathlib.Path, success: bool = True): path.joinpath("pytest_success").write_text(str(int(success))) @@ -201,7 +200,9 @@ def report_subtest(request: pytest.FixtureRequest): verbose = request.config.getoption("verbose") do_capture = request.config.getoption("distributed_capture") - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: + def do_report_subtest(path: pathlib.Path, world_size: int, use_cuda: bool = True) -> None: + if use_cuda and torch.cuda.device_count() < world_size: + pytest.skip(f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {world_size}") success = check_subtest_success(path) if not do_capture: logger.warning("Distributed capture is disabled. See distributed test for run output.")