diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 737f6fa3c..64ecbde1f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/Dockerfile b/Dockerfile index 5804d0e47..7ff5d7a74 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/data/data/data_loader.py b/fast_llm/data/data/data_loader.py index a407c0258..ba7e5e612 100644 --- a/fast_llm/data/data/data_loader.py +++ b/fast_llm/data/data/data_loader.py @@ -1,7 +1,10 @@ +import itertools import typing import torch.utils.data +from fast_llm.core.distributed import broadcast_object + class SampledDatasetIterator(torch.utils.data.Sampler): """ @@ -23,3 +26,47 @@ def __len__(self) -> int: def __iter__(self) -> typing.Iterator[list[int]]: for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): yield list(range(idx + self._start_idx, idx + self._end_idx)) + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. + """ + + def __init__( + self, + data_loader: torch.utils.data.dataloader.DataLoader, + process_group: torch.distributed.ProcessGroup | None, + ): + self._data_loader = data_loader + self._rank = 0 if process_group is None else process_group.rank() + self._process_group = process_group + + def __iter__(self): + if self._rank == 0: + self._iterator = iter(self._data_loader) + else: + self._iterator = itertools.repeat(None) + if self._process_group is None: + return self._iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. 
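A minimal sketch of the direct-tensor broadcast the TODO above describes, assuming the batch can expose a flat dict of tensors. `broadcast_state_dict` and its signature are hypothetical, not part of this change; with the NCCL backend the tensors would additionally need to live on the GPU.

```python
import torch
import torch.distributed


def broadcast_state_dict(
    state_dict: dict[str, torch.Tensor] | None,
    group: torch.distributed.ProcessGroup | None,
    src_global_rank: int = 0,
) -> dict[str, torch.Tensor]:
    # The source rank sends the (small, pickled) metadata once, then each tensor is
    # broadcast directly into a pre-allocated buffer, so tensor data is never pickled.
    is_src = torch.distributed.get_rank() == src_global_rank
    meta = [[(name, tensor.shape, tensor.dtype) for name, tensor in state_dict.items()]] if is_src else [None]
    torch.distributed.broadcast_object_list(meta, src=src_global_rank, group=group)
    result = {}
    for name, shape, dtype in meta[0]:
        tensor = state_dict[name] if is_src else torch.empty(shape, dtype=dtype)
        torch.distributed.broadcast(tensor, src=src_global_rank, group=group)
        result[name] = tensor
    return result
```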
+ + try: + data = next(self._iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self._process_group, 0) + + if isinstance(data, Exception): + raise data + + return data diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 17f151919..3af86652a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,7 +8,7 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data -from fast_llm.data.data.data_loader import SampledDatasetIterator +from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters @@ -116,20 +116,23 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - return iter( - torch.utils.data.DataLoader( - self._datasets[dataset_name], # noqa - batch_sampler=SampledDatasetIterator( - total_samples=len(self._datasets[dataset_name]), - begin_index=consumed_samples, - micro_batch_size=batch_config.micro_batch_size, - data_rank=self._distributed.config.batch_data_rank, - data_parallel=self._distributed.config.batch_data_parallel, - ), - num_workers=num_workers, - prefetch_factor=prefetch_factor, - pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, - ) + data_loader = torch.utils.data.DataLoader( + self._datasets[dataset_name], # noqa + batch_sampler=SampledDatasetIterator( + total_samples=len(self._datasets[dataset_name]), + begin_index=consumed_samples, + micro_batch_size=batch_config.micro_batch_size, + data_rank=self._distributed.config.batch_data_rank, + data_parallel=self._distributed.config.batch_data_parallel, + ), + num_workers=num_workers, + prefetch_factor=prefetch_factor, + pin_memory=True, + collate_fn=LanguageModelBatch.from_samples, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) + + if self._datasets[dataset_name].requires_broadcast: + data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group) + + return iter(data_loader) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 33942708b..520e6d0af 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -5,6 +5,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData + from fast_llm.data.dataset.sampled import SampledIterableDataset class Dataset[SampleType: Sample](abc.ABC): @@ -27,6 +28,14 @@ def __getstate__(self): del state["__orig_class__"] return state + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. 
+ """ + return False + class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ @@ -48,3 +57,14 @@ class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass + + +class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]): + @abc.abstractmethod + def iterate(self, sampling: "SamplingData") -> typing.Iterator[SampleType]: + pass + + def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]": + from fast_llm.data.dataset.sampled import SampledIterableDataset + + return SampledIterableDataset(self, config) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 2858d8d18..8b18b59ba 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -15,6 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -298,3 +299,49 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) else: raise FileNotFoundError(self.path) + + +REDIS_DATA_STREAM = "fast_llm_streaming" +REDIS_GROUP_NAME = "fast_llm_group" + + +@config_class() +class RedisConfig(Config): + REDIS_FIELD: typing.ClassVar[str] = "data" + REDIS_FIELD_B: typing.ClassVar[bytes] = REDIS_FIELD.encode() + REDIS_GROUP_NAME: typing.ClassVar[str] = "fast_llm_group" + REDIS_GROUP_NAME_B: typing.ClassVar[bytes] = REDIS_GROUP_NAME.encode() + + # TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config? + host: str = Field( + default="localhost", + desc="Hostname or IP address of the Redis server.", + hint=FieldHint.core, + ) + + port: int = Field( + default=6379, + desc="Port number on which the Redis server is running.", + hint=FieldHint.core, + ) + + def get_client(self): + import redis + + return redis.Redis(self.host, self.port) + + +@config_class(dynamic_type={SampledDatasetConfig: "streaming"}) +class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]): + """ + Configuration for a streaming dataset that reads training data from a Redis stream. + """ + + _abstract = False + + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + from fast_llm.data.dataset.streaming import RedisStreamingDataset + + return RedisStreamingDataset[StreamingDatasetConfig, SampleType](self, sampling.distributed.config).sample( + sampling + ) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 01f3195e4..27070f674 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -51,3 +51,11 @@ def __getitem__(self, index: int) -> SampleType: @property def name(self) -> str: return self._dataset.name + + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. 
+ """ + return self._dataset.requires_broadcast diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 36b52d9f8..35504df9f 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -8,7 +8,7 @@ import torch import yaml -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -111,6 +111,10 @@ def __init__( # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `Data`. + @property + def requires_broadcast(self) -> bool: + return self._indexed_dataset.requires_broadcast + def _sample(self) -> None: """ Create a `SampledDataset` with the requested parameters. @@ -429,3 +433,60 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + + +class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]): + def __init__( + self, + dataset: SamplableIterableDataset[SampleType], + sampling: SamplingData, + ): + self._dataset = dataset + self._sampling = sampling + self._documents: list[SampleType] = [] + self._current_length = 0 + self._sample_length = self._sampling.parameters.sequence_length + self._sampling.parameters.extra_tokens + # Delay iterator creation to avoid pickling issues. + self._iterator: typing.Iterator[SampleType] | None = None + + @property + def requires_broadcast(self) -> bool: + # TODO: ====== fix ====== + # return self._iterator.requires_broadcast + return True + + def __getitem__(self, index: int) -> SampleType: + if self._iterator is None: + self._iterator = self._dataset.iterate(self._sampling) + while self._current_length < self._sample_length: + document = next(self._iterator) + if len(document) > self._sample_length: + logging.warning(f"Dropping document with length {len(document)} > {self._sample_length}.") + continue + self._documents.append(document) + self._current_length += len(document) + + if self._current_length == self._sample_length: + documents = self._documents + self._documents = [] + self._current_length = 0 + else: + last_length = len(self._documents[-1]) + remaining_length = last_length - (self._current_length - self._sample_length) + if self._sampling.parameters.truncate_documents: + documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)] + self._documents = [self._documents[-1].crop(remaining_length, last_length)] + else: + documents = self._documents[:-1] + [self._documents[0].get_padding(remaining_length)] + self._documents = [self._documents[-1]] + self._current_length = len(self._documents[0]) + sample = documents[0].from_documents(documents) + Assert.eq(len(sample), self._sample_length) + return sample + + def __len__(self) -> int: + return self._sampling.parameters.num_samples + + @property + def name(self) -> str: + return self._dataset.name diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py new file mode 100644 index 000000000..8e3719986 --- /dev/null +++ b/fast_llm/data/dataset/streaming.py @@ -0,0 +1,173 @@ +import functools +import json +import typing + +import redis +import torch.utils.data + +from fast_llm.config import Config, 
Configurable, Field, config_class +from fast_llm.data.dataset.abstract import SamplableIterableDataset +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, SamplingData, StreamingDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.data.sample.token_data import TokenDataSample +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.utils import Assert + + +@config_class() +class RedisDocument(Config): + """ + Schema for sending and receiving documents through redis, and the associated handling code. + """ + + tokens: torch.Tensor = Field() + loss_masking_spans: list[tuple[int, int]] | None = Field(default=None) + chosen_span: tuple[int, int] | None = Field(default=None) + rejected_span: tuple[int, int] | None = Field(default=None) + advantage: float | None = Field(default=None) + old_log_probabilities: torch.Tensor | None = Field(default=None) + + def _validate(self): + # Decode message + if isinstance(self.tokens, bytes): + self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) + elif isinstance(self.tokens, (list, tuple)): + self.tokens = torch.tensor(self.tokens, dtype=torch.int64) + if isinstance(self.loss_masking_spans, str): + self.loss_masking_spans = json.loads(self.loss_masking_spans) + if isinstance(self.chosen_span, str): + self.chosen_span = json.loads(self.chosen_span) + if isinstance(self.rejected_span, str): + self.rejected_span = json.loads(self.rejected_span) + if isinstance(self.old_log_probabilities, bytes): + self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) + elif isinstance(self.old_log_probabilities, (list, tuple)): + self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) + super()._validate() + if self.old_log_probabilities is not None: + Assert.eq(len(self.old_log_probabilities), self.num_tokens) + + @functools.cached_property + def num_tokens(self) -> int: + return len(self.tokens) + + @classmethod + def from_message(cls, message: dict[bytes, bytes]) -> typing.Self: + # Read + kwargs = {} + for key, value in message.items(): + key = key.decode() + if key == "data": + kwargs.update(json.loads(value)) + else: + kwargs[key] = value + return cls.from_dict(kwargs) + + def to_message(self) -> dict[str, str | int | float | bytes]: + # Encode message + message: dict[str, str | int | float | bytes] = {"tokens": self.tokens.numpy().tobytes()} + if self.old_log_probabilities is not None: + message["old_log_probabilities"] = self.old_log_probabilities.numpy().tobytes() + data = {} + if self.loss_masking_spans is not None: + data["loss_masking_spans"] = self.loss_masking_spans + if self.chosen_span is not None: + data["chosen_span"] = self.chosen_span + if self.rejected_span is not None: + data["rejected_span"] = self.rejected_span + if self.advantage is not None: + data["advantage"] = self.advantage + if data: + message["data"] = json.dumps(data) + return message + + def to_sample(self, preprocessing: LanguageModelPreprocessingConfig | None): + sample_size = len(self.tokens) + # TODO: Check explicitly that required data is available? 
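For reference, a hypothetical producer-side sketch of the wire format handled by `to_message`/`from_message`: an external process pushes one document onto the stream that `RedisStreamingDataset` consumes. The connection parameters and field values are made up.

```python
import redis

from fast_llm.data.dataset.config import REDIS_DATA_STREAM
from fast_llm.data.dataset.streaming import RedisDocument

client = redis.Redis(host="localhost", port=6379)  # assumed server location
document = RedisDocument.from_dict(
    {
        "tokens": list(range(16)),
        "advantage": 0.5,  # made-up value
        "old_log_probabilities": [0.0] * 16,
    }
)
# `to_message` packs the token tensors as raw bytes and the remaining optional
# fields into a JSON "data" blob, matching what `from_message` expects.
client.xadd(REDIS_DATA_STREAM, document.to_message())
```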
+ return LanguageModelSample( + tokens=TokenSample(self.tokens, [sample_size]), + loss_masking_spans=( + RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size) + if preprocessing.use_loss_masking_spans + else None + ), + chosen_spans=RangeSample([self.chosen_span], sample_size) if preprocessing.use_preference_spans else None, + rejected_spans=( + RangeSample([self.rejected_span], sample_size) if preprocessing.use_preference_spans else None + ), + advantages=( + TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32)) + if preprocessing.use_advantages + else None + ), + old_log_probabilities=( + TokenDataSample(self.old_log_probabilities) if preprocessing.use_old_log_probabilities else None + ), + ) + + +class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( + Configurable[ConfigType], SamplableIterableDataset[SampleType] +): + def __init__(self, config: ConfigType, distributed_config: DistributedConfig): + super().__init__(config) + self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_STREAM}|{REDIS_GROUP_NAME})[data]" + self._config = config + self._rank = distributed_config.batch_data_rank + self.is_batch_data_group_leader = ( + distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank == 0 + ) + + @property + def requires_broadcast(self) -> bool: + return True + + @property + def name(self) -> str: + return self._name + + def iterate(self, sampling: SamplingData) -> typing.Iterator[SampleType]: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError("StreamingDataset can work only with one instance per rank") + + if not self.is_batch_data_group_leader: + raise RuntimeError("Must be only called on the batch data group leader") + + client = redis.Redis(host=self._config.host, port=self._config.port) + + # Create the consumer group at the start of the stream ("0") + # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True + try: + client.xgroup_create(name=REDIS_DATA_STREAM, groupname=REDIS_GROUP_NAME, id="0", mkstream=True) + except redis.exceptions.ResponseError as e: + if "BUSYGROUP" in str(e): + # Consumer group already exists + pass + else: + raise + + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = client.xreadgroup( + groupname=REDIS_GROUP_NAME, + consumername=f"fast_llm_consumer_{self._rank}", + # ">" reads only new messages that have not been delivered to any consumer + streams={REDIS_DATA_STREAM: ">"}, + count=1, + block=1000, + # No explicit ACK: messages are processed immediately; on rank failure the job restarts, + # so message loss is acceptable and simplifies coordination + noack=True, + ) + if messages: + for stream_key, messages_ in messages: + assert stream_key == REDIS_DATA_STREAM.encode() + for message_id, message in messages_: + print(message) + yield RedisDocument.from_message(message).to_sample(sampling.preprocessing) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index a1aadf40a..1645e4dea 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -48,6 +48,10 @@ def has_preference_spans(self) -> bool: def has_images(self) -> bool: return False + @functools.cached_property + def has_grpo_data(self) 
-> bool:
+        return False
+
 
 
 @config_class(dynamic_type={LanguageModelSourceConfig: "document"})
 class DocumentSourceConfig(LanguageModelSourceConfig):
@@ -91,6 +95,13 @@ class DocumentSourceConfig(LanguageModelSourceConfig):
         desc="Field containing image positions in the text.",
         hint=FieldHint.optional,
     )
+    # TODO: Old log probabilities are made up (zeros) since we don't know the token count in advance.
+    advantages: str | None = Field(
+        default=None,
+        desc="Field containing advantages for policy optimization."
+        " Mainly for debugging purposes, as advantages are typically generated at runtime.",
+        hint=FieldHint.optional,
+    )
 
     @functools.cached_property
     def columns(self) -> list[str]:
@@ -117,6 +128,10 @@ def has_images(self) -> bool:
         Assert.eq(self.images is None, self.image_positions is None)
         return self.images is not None
 
+    @functools.cached_property
+    def has_grpo_data(self) -> bool:
+        return self.advantages is not None
+
     def _validate(self):
         super()._validate()
         if self.has_preference_spans and self.has_loss_masking_span:
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 325d33c43..7fc2e35af 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -42,6 +42,7 @@
 from fast_llm.data.sample.patch import PatchSample
 from fast_llm.data.sample.range import RangeSample
 from fast_llm.data.sample.token import TokenSample
+from fast_llm.data.sample.token_data import TokenDataSample
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import normalize_probabilities, padded_cumsum
@@ -222,11 +223,14 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig:
             ),
             use_loss_masking_spans=self._source_schema.has_loss_masking_span,
             use_preference_spans=self._source_schema.has_preference_spans,
+            use_grpo_data=self._source_schema.has_grpo_data,
         )
 
     def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
         token_spans_by_type = collections.defaultdict(list)
-        image_patches = image_token_maps = image_position_ids = patch_counts = None
+        image_patches = image_token_maps = image_position_ids = patch_counts = advantages = old_log_probabilities = (
+            None
+        )
 
         if isinstance(self._source_schema, ConversationSourceConfig):
             # Conversation format: tokenize messages and get loss masking spans from chat template
@@ -332,31 +336,39 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
         else:
             raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}")
 
+        if self._source_schema.has_grpo_data:
+            advantages = torch.full_like(tokens, sample[self._source_schema.advantages], dtype=torch.float32)
+            old_log_probabilities = torch.zeros_like(tokens, dtype=torch.float32)
+
         sample_size = len(tokens)
         return LanguageModelSample(
             TokenSample(tokens, [sample_size]),
-            (
+            loss_masking_spans=(
                 RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size)
                 if self._source_schema.has_loss_masking_span
                 else None
             ),
-            (
+            chosen_spans=(
                 RangeSample(token_spans_by_type[SpanType.chosen], sample_size)
                 if self._source_schema.has_preference_spans
                 else None
             ),
-            (
+            rejected_spans=(
                 # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it.
RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) if self._source_schema.has_preference_spans else None ), - ( + image_patches=( PatchSample(image_patches, image_token_maps, image_position_ids, sample_size, patch_counts) if self._source_schema.has_images else None ), + advantages=TokenDataSample(advantages) if self._source_schema.has_grpo_data else None, + old_log_probabilities=( + TokenDataSample(old_log_probabilities) if self._source_schema.has_grpo_data else None + ), ) def generate_config_yaml_for_sharded_dst( diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index d54776eec..a4b4a2ff8 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -22,6 +22,7 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): vocab_size: int | None = Field(default=None) use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) + use_grpo_data: bool = Field(default=False) def _validate(self) -> None: super()._validate() @@ -32,6 +33,14 @@ def _validate(self) -> None: def use_image_patches(self) -> bool: return isinstance(self.image_patches, ImagePatchConfig) + @functools.cached_property + def use_advantages(self) -> bool: + return self.use_grpo_data + + @functools.cached_property + def use_old_log_probabilities(self) -> bool: + return self.use_grpo_data + def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 22b89acf1..aa09e1467 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -39,6 +39,13 @@ RangeWriter, ) from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.data.sample.token_data import ( + TokenDataBatch, + TokenDataReader, + TokenDataReaderConfig, + TokenDataSample, + TokenDataWriter, +) from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -53,12 +60,16 @@ def __init__( chosen_spans: RangeSample | None = None, rejected_spans: RangeSample | None = None, image_patches: PatchSample | None = None, + advantages: TokenDataSample | None = None, + old_log_probabilities: TokenDataSample | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.advantages = advantages + self.old_log_probabilities = old_log_probabilities @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: @@ -68,6 +79,10 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), _merge_optional(PatchSample.from_documents, [document.image_patches for document in documents]), + _merge_optional(TokenDataSample.from_documents, [document.advantages for document in documents]), + _merge_optional( + TokenDataSample.from_documents, [document.old_log_probabilities for document in documents] + ), ) def crop(self, begin: int, end: int) -> 
typing.Self: @@ -77,18 +92,22 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return LanguageModelSample( + return self.__class__( self.tokens.get_padding(size), None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), None if self.chosen_spans is None else self.chosen_spans.get_padding(size), None if self.rejected_spans is None else self.rejected_spans.get_padding(size), None if self.image_patches is None else self.image_patches.get_padding(size), + None if self.advantages is None else self.advantages.get_padding(size), + None if self.old_log_probabilities is None else self.old_log_probabilities.get_padding(size), ) @@ -100,12 +119,16 @@ def __init__( chosen_spans: RangeBatch | None = None, rejected_spans: RangeBatch | None = None, image_patches: PatchBatch | None = None, + advantages: TokenDataBatch | None = None, + old_log_probabilities: TokenDataBatch | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.advantages = advantages + self.old_log_probabilities = old_log_probabilities @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: @@ -115,6 +138,8 @@ def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.S _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.advantages for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.old_log_probabilities for sample in samples]), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -124,6 +149,8 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def to_device_(self, device: "torch.device | str"): @@ -136,6 +163,10 @@ def to_device_(self, device: "torch.device | str"): self.rejected_spans.to_device_(device) if self.image_patches is not None: self.image_patches.to_device_(device) + if self.advantages is not None: + self.advantages.to_device_(device) + if self.old_log_probabilities is not None: + self.old_log_probabilities.to_device_(device) def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: @@ -157,6 +188,8 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() image_patches: MemmapReaderBaseConfig = Field() + advantages: MemmapReaderBaseConfig = Field() + old_log_probabilities: MemmapReaderBaseConfig = Field() def _validate(self) -> None: super()._validate() @@ -192,6 +225,16 @@ def _validate(self) -> None: self.rejected_spans, RangeReaderConfig if 
self.preprocessing.use_preference_spans else NullReaderConfig, ) + Assert.custom( + isinstance, + self.advantages, + TokenDataReaderConfig if self.preprocessing.use_advantages else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.old_log_probabilities, + TokenDataReaderConfig if self.preprocessing.use_old_log_probabilities else NullReaderConfig, + ) if self.preprocessing.use_image_patches: Assert.custom(isinstance, self.image_patches, PatchReaderConfig) Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) @@ -222,6 +265,8 @@ def _expected_buffer_size(self) -> int: + self.chosen_spans.expected_buffer_size + self.rejected_spans.expected_buffer_size + self.image_patches.expected_buffer_size + + self.advantages.expected_buffer_size + + self.old_log_probabilities.expected_buffer_size ) def get_metadata(self) -> dict[str, typing.Any]: @@ -235,6 +280,10 @@ def get_metadata(self) -> dict[str, typing.Any]: out["rejected_spans"] = self.rejected_spans.get_metadata() if not isinstance(self.image_patches, NullReaderConfig): out["image_patches"] = self.image_patches.get_metadata() + if not isinstance(self.advantages, NullReaderConfig): + out["advantages"] = self.advantages.get_metadata() + if not isinstance(self.old_log_probabilities, NullReaderConfig): + out["old_log_probabilities"] = self.old_log_probabilities.get_metadata() return out @classmethod @@ -257,6 +306,12 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi out["image_patches"] = PatchReaderConfig.blend_metadata( [metadata_["image_patches"] for metadata_ in metadata] ) + if "advantages" in metadata[0]: + out["advantages"] = RangeReaderConfig.blend_metadata([metadata_["advantages"] for metadata_ in metadata]) + if "old_log_probabilities" in metadata[0]: + out["old_log_probabilities"] = RangeReaderConfig.blend_metadata( + [metadata_["old_log_probabilities"] for metadata_ in metadata] + ) return out @@ -290,6 +345,10 @@ def __init__( self._chosen_spans = self._config.chosen_spans.get_reader(buffer) self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + if self._model_preprocessing.use_advantages: + self._advantages = self._config.advantages.get_reader(buffer) + self._old_log_probabilities = self._config.old_log_probabilities.get_reader(buffer) + if self._model_preprocessing.use_image_patches: model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches if isinstance(self._config.image_patches, NullReaderConfig): @@ -334,6 +393,12 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: else None ), image_patches, + (self._advantages.get_document(index, begin, end) if self._model_preprocessing.use_advantages else None), + ( + self._old_log_probabilities.get_document(index, begin, end) + if self._model_preprocessing.use_old_log_probabilities + else None + ), ) def get_document_sizes(self) -> torch.Tensor: @@ -356,6 +421,10 @@ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dic metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader): metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index) + if hasattr(self, "_advantages") and isinstance(self._advantages, TokenDataReader): + metadata["advantages"] = self._advantages.get_split(begin_index, end_index) + if hasattr(self, "_old_log_probabilities") and isinstance(self._old_log_probabilities, 
TokenDataReader): + metadata["old_log_probabilities"] = self._old_log_probabilities.get_split(begin_index, end_index) return begin_index, end_index, metadata @@ -379,6 +448,12 @@ def __enter__(self): self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() if self._preprocessing_config.use_image_patches: self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() + if self._preprocessing_config.use_advantages: + self._advantages_writer = TokenDataWriter(self._path.joinpath("advantages")).__enter__() + if self._preprocessing_config.use_old_log_probabilities: + self._old_log_probabilities_writer = TokenDataWriter( + self._path.joinpath("old_log_probabilities") + ).__enter__() return self def write(self, document: LanguageModelSample): @@ -403,6 +478,14 @@ def write(self, document: LanguageModelSample): assert document.image_patches is not None self._image_patches_writer.write(document.image_patches) + if self._preprocessing_config.use_advantages: + assert document.advantages is not None + self._advantages_writer.write(document.advantages) + + if self._preprocessing_config.use_old_log_probabilities: + assert document.old_log_probabilities is not None + self._old_log_probabilities_writer.write(document.old_log_probabilities) + def __exit__(self, exc_type, exc_val, exc_tb): self._token_writer.__exit__(exc_type, exc_val, exc_tb) if self._preprocessing_config.use_loss_masking_spans: @@ -412,6 +495,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) if self._preprocessing_config.use_image_patches: self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_advantages: + self._advantages_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_old_log_probabilities: + self._old_log_probabilities_writer.__exit__(exc_type, exc_val, exc_tb) if exc_type is None: # A dummy config so we can verify the begin and end offsets. 
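A hedged sketch of how the new per-field writers are exercised in isolation; in the real flow `LanguageModelWriter` creates them internally as above, and the output path here is made up.

```python
import pathlib

import torch

from fast_llm.data.sample.token_data import TokenDataSample, TokenDataWriter

path = pathlib.Path("/tmp/advantages_example")  # made-up location
with TokenDataWriter(path) as writer:
    # One entry per document; the writer records a cumulative-size index so that
    # `TokenDataReader.get_document` can later slice arbitrary token ranges.
    for num_tokens in (5, 3, 7):
        writer.write(TokenDataSample(torch.zeros(num_tokens, dtype=torch.float32)))
```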
@@ -446,6 +533,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): config.image_patches.begin, config.image_patches.end, ) + if self._preprocessing_config.use_advantages: + _copy_chunked( + self._path.joinpath("advantages"), + self._stream, + config.advantages.begin, + config.advantages.end, + ) + if self._preprocessing_config.use_old_log_probabilities: + _copy_chunked( + self._path.joinpath("old_log_probabilities"), + self._stream, + config.old_log_probabilities.begin, + config.old_log_probabilities.end, + ) self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) @@ -475,6 +576,16 @@ def _get_config(self, begin: int, end: int | None): offset = image_patches.end else: image_patches = NullReaderConfig() + if self._preprocessing_config.use_advantages: + advantages = self._advantages_writer.get_config(offset) + offset = advantages.end + else: + advantages = NullReaderConfig() + if self._preprocessing_config.use_old_log_probabilities: + old_log_probabilities = self._old_log_probabilities_writer.get_config(offset) + offset = old_log_probabilities.end + else: + old_log_probabilities = NullReaderConfig() if end is None: end = offset + len(LanguageModelReaderConfig.footer) @@ -488,6 +599,8 @@ def _get_config(self, begin: int, end: int | None): rejected_spans=rejected_spans, image_patches=image_patches, preprocessing=self._preprocessing_config, + advantages=advantages, + old_log_probabilities=old_log_probabilities, ) diff --git a/fast_llm/data/sample/token_data.py b/fast_llm/data/sample/token_data.py new file mode 100644 index 000000000..cf094bd4d --- /dev/null +++ b/fast_llm/data/sample/token_data.py @@ -0,0 +1,189 @@ +import functools +import math +import typing + +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.sample.abstract import ( + Batch, + MemmapReader, + MemmapReaderBase, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.data.sample.patch import PatchReaderBaseConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, get_unique + + +class TokenDataSample(Sample): + """ + A reusable component holding tensor-valued data of fixed dtype and shape for each token. + TODO: Use as base class for `TokenSample` and `PatchSample`? 
+ """ + + def __init__(self, data: torch.Tensor): + self.data = data + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls(torch.cat([document.data for document in documents])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[begin:end]) + + def __len__(self) -> int: + return len(self.data) + + def get_padding(self, size: int) -> typing.Self: + return self.__class__(torch.full([size], 0, dtype=self.data.dtype)) + + +class TokenDataBatch(Batch): + def __init__(self, data: torch.Tensor) -> None: + self.data = data + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenDataSample]) -> typing.Self: + return cls(torch.stack([sample.data for sample in samples])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[:, begin:end]) + + def to_device_(self, device: "torch.device | str"): + self.data = self.data.to(device, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token_data"}) +class TokenDataReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"token data begin" + footer: typing.ClassVar[bytes] = b"token data end" + num_documents: int = Field() + num_tokens: int = Field() + shape: tuple[int, ...] = Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @functools.cached_property + def size(self) -> int: + return math.prod(self.shape) + + @property + def reader_class(self) -> "type[TokenDataReader]": + return TokenDataReader + + @property + def writer_class(self) -> "type[TokenDataWriter]": + return TokenDataWriter + + @property + def _expected_buffer_size(self) -> int: + return ( + self.num_tokens * self.data_type.torch.itemsize * self.size + + (self.num_documents + 1) * torch.int64.itemsize + ) + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_tokens": self.num_tokens, + "num_documents": self.num_documents, + "data_type": str(self.data_type), + "shape": self.shape, + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata), + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + "shape": get_unique(metadata_["shape"] for metadata_ in metadata), + } + + +class TokenDataReader[ConfigType: TokenDataReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._data = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens * self._config.size, + ).view(-1, *self._config.shape) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._data.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenDataSample(self._data[begin_ + begin : begin_ + end]) + + def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) + + return { + "num_tokens": self._size_cumsums[end_index].item() - 
self._size_cumsums[begin_index].item(), + "num_documents": end_index - begin_index, + "data_type": str(self._config.data_type), + "shape": self._config.shape, + } + + +class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + # TODO: Does this make sense? + return TokenDataSample(torch.zeros(end - begin, *self._config.shape, dtype=self._config.data_type.torch)) + + +def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: + left = torch.searchsorted(cumsum, value, side="right") + if left == len(cumsum): + return left.item() + return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +class TokenDataWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + self._shape = None + return self + + def write(self, document: TokenDataSample): + super().write(document) + if self._data_type is None: + self._data_type = document.data.dtype + self._shape = document.data.shape[1:] + else: + Assert.eq(self._data_type, document.data.dtype) + Assert.eq(self._shape, document.data.shape[1:]) + self._stream.write(document.data.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.data)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[TokenDataReaderConfig]: + return TokenDataReaderConfig + + def _get_config(self, begin: int, end: int): + return TokenDataReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + shape=self._shape, + data_type=DataType.from_torch(self._data_type), + preprocessing=self._preprocessing_config, + ) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 7a257a5fa..32eea2db6 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -15,6 +15,7 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig, + CheckpointStateSaveConfigBase, FastLLMCheckpointFormat, export_safetensors_metadata, ) @@ -71,6 +72,31 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + def iter_tensors( + self, config: CheckpointStateSaveConfigBase, metadata: "CheckpointMetadata" + ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `state_dict` until that tensor is available. 
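To make the contract in the comment above concrete, a toy illustration (not the real converter, names invented): entries are popped only once everything they depend on is present, so a tensor that still misses a partner simply stays in `state_dict` for a later call.

```python
import torch


def convert_state_dict(state_dict: dict[str, torch.Tensor], export: bool) -> dict[str, torch.Tensor]:
    out = {}
    # Hypothetical fused projection: converted (and popped) only once all three pieces arrived.
    if all(name in state_dict for name in ("query", "key", "value")):
        out["qkv_proj"] = torch.cat([state_dict.pop(name) for name in ("query", "key", "value")])
    # One-to-one entries (here just a rename) can be converted immediately.
    if "word_embeddings" in state_dict:
        out["embed_tokens"] = state_dict.pop("word_embeddings")
    return out
```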
+        state_dict = {}
+        for parameter_name, shard_name, tensor in self._model.get_state_tensor_iterator(
+            self.get_shard_names(config), config.data_type
+        ):
+            if shard_name not in state_dict:
+                state_dict[shard_name] = {}
+            shard_state_dict = state_dict[shard_name]
+            assert parameter_name not in shard_state_dict
+            shard_state_dict[parameter_name] = tensor
+            for exported_name, exported_tensor in self._convert_state_dict(shard_state_dict, True).items():
+                yield shard_name, self._get_key(exported_name, shard_name), exported_tensor
+
+        for shard_name, shard_state_dict in state_dict.items():
+            assert (
+                not shard_state_dict
+            ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}"
+
     @classmethod
     @abc.abstractmethod
     def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py
index ccde838e8..48c2d3c55 100644
--- a/fast_llm/engine/multi_stage/fast_llm_model.py
+++ b/fast_llm/engine/multi_stage/fast_llm_model.py
@@ -1,9 +1,11 @@
 import logging
 import typing
 
+import torch
+
 from fast_llm.config import UpdateType
 from fast_llm.core.distributed import broadcast
-from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig
+from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointStateSaveConfigBase
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
 from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
@@ -30,6 +32,20 @@ def save_checkpoint(
         )
         converter.save(config, fast_llm_metadata)
 
+    def iter_checkpoint(
+        self,
+        config: CheckpointStateSaveConfigBase,
+        extra_metadata: dict | None = None,
+    ) -> typing.Iterator[tuple[str, str, torch.Tensor]]:
+        # TODO: Handle barriers, ok file, mkdir, etc. here
+        converter = config.format.get_handler_class()(self)
+        fast_llm_metadata = self._config.to_metadata(
+            config,
+            shards=converter.get_shard_names(config),
+            metadata={} if extra_metadata is None else extra_metadata,
+        )
+        yield from converter.iter_tensors(config, fast_llm_metadata)
+
     def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         # TODO: Simplify branching.
         # TODO: Test with more distributed configs.
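A hedged usage sketch of the new `iter_checkpoint` entry point; the helper below is illustrative, not part of the change. Tensors are yielded one at a time per shard, in export format, which is what the streaming callback further down relies on.

```python
from fast_llm.engine.checkpoint.config import CheckpointStateSaveConfigBase
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel


def log_exported_tensors(model: FastLLMModel, export_config: CheckpointStateSaveConfigBase) -> None:
    # Nothing has to be materialized on disk before the tensors can be sent elsewhere.
    for shard_name, key, tensor in model.iter_checkpoint(export_config):
        print(f"{shard_name}/{key}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
```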
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 867cca984..0b492703c 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -7,6 +7,7 @@ from fast_llm.config import ( Config, + Configurable, Field, FieldHint, FieldUpdate, @@ -16,6 +17,7 @@ skip_valid_if_none, ) from fast_llm.data.data.config import DataConfig +from fast_llm.data.dataset.config import RedisConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -24,6 +26,7 @@ ) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig @@ -32,6 +35,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel + from fast_llm.engine.training.streaming import StreamingTrainerCallback from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @@ -321,6 +326,65 @@ def _validate(self) -> None: self.wandb.alert.assert_sub_interval(self.logs) +@config_class(registry=True) +class TrainerCallbackConfig(Config): + def get_callback(self, model: "FastLLMModel") -> "TrainerCallback": + raise NotImplementedError() + + def setup(self, config: "TrainerConfig") -> None: + pass + + +@config_class() +class WeightsBroadcastConfig(Config): + # TODO: Have the external model send these instead? + host: str = Field( + default="localhost", + desc="Master address for the external NCCL process group.", + hint=FieldHint.feature, + ) + port: int = Field( + default=23456, + desc="Master port for the external NCCL process group.", + hint=FieldHint.feature, + ) + external_world_size: int = Field( + default=1, + desc="World size of the external NCCL process group.", + hint=FieldHint.feature, + ) + backend: DistributedBackend = Field( + default=DistributedBackend.nccl, + desc="Backend for the external NCCL process group.", + hint=FieldHint.feature, + ) + + +@config_class(dynamic_type={TrainerCallbackConfig: "streaming"}) +class StreamingTrainerCallbackConfig(TrainerCallbackConfig, RedisConfig): + """ + Aggregates all trainer-side Redis-based event configurations. 
+ """ + + broadcast: WeightsBroadcastConfig = Field( + desc="Configuration for signaling weight-ready events via Redis.", + hint=FieldHint.core, + ) + + export: CheckpointStateSaveConfigBase = Field( + desc="Configuration for exporting checkpoints before broadcasting them.", + hint=FieldHint.core, + ) + + def get_callback(self, model: "FastLLMModel") -> "StreamingTrainerCallback": + from fast_llm.engine.training.streaming import StreamingTrainerCallback + + return StreamingTrainerCallback(self, model) + + def setup(self, config: "TrainerConfig") -> None: + self.export.setup(config.model) + + @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True @@ -352,10 +416,19 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) + callbacks: dict[str, TrainerCallbackConfig] = Field( + default_factory=dict, + desc="Configuration for training callbacks.", + hint=FieldHint.feature, + ) + def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): self._add_reference_distributed_to_pretrained(reference_model) + for callback in self.callbacks.values(): + # We don't know anything about the callbacks, so we forward `self` and let them handle their own setup. + callback.setup(self) super()._validate() if self.reference_models: # TODO: Add support. @@ -403,3 +476,21 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) + + +class TrainerCallback[ConfigType: TrainerCallbackConfig](Configurable[ConfigType]): + # TODO: Make a more exhaustive set of events and arguments. + def run_begin(self, step: int): + pass + + def step_end( + self, + step: int, + reduced_losses: dict[str, float | int], + update_successful: bool, + train_metrics: dict[str, typing.Any] | None, + ): + pass + + def train_end(self, step: int): + pass diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py new file mode 100644 index 000000000..098a63efe --- /dev/null +++ b/fast_llm/engine/training/streaming.py @@ -0,0 +1,77 @@ +import json +import logging +import typing + +import torch.distributed + +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.training.config import StreamingTrainerCallbackConfig, TrainerCallback + +logger = logging.getLogger(__name__) + + +REDIS_TRAINING_STREAM = "fast_llm_events" +REDIS_TRAINING_FIELD = "event" + + +class StreamingTrainerCallback[ConfigType: StreamingTrainerCallbackConfig](TrainerCallback[ConfigType]): + def __init__(self, config: ConfigType, model: "FastLLMModel"): + super().__init__(config) + self._model = model + self._do_broadcast = self._model.config.distributed.rank == 0 + if self._do_broadcast: + self._client = self._config.get_client() + init_method = f"tcp://{config.broadcast.host}:{config.broadcast.port}" + logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") + # TODO: Create a custom process group instead. + self._process_group = torch.distributed.init_process_group( + backend=str(self._config.broadcast.backend), + init_method=init_method, + world_size=config.broadcast.external_world_size + 1, + rank=0, + ) + logger.info(f"Weights broadcast rendezvous at {init_method} connected") + + def run_begin(self, step: int): + # TODO: ====== Send a train / run begin signal? 
====== + self._broadcast_weights(step) + + def step_end( + self, + step: int, + reduced_losses: dict[str, float | int], + update_successful: bool, + train_metrics: dict[str, typing.Any] | None, + ): + if update_successful: + self._broadcast_weights(step) + + def train_end(self, step: int): + # TODO: ====== Send something on unsuccessful ends? ====== + if self._do_broadcast: + self._client.xadd(REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "training_finished"})}) + self._clear() + + def __del__(self): + self._clear() + + def _clear(self): + if hasattr(self, "_process_group"): + torch.distributed.destroy_process_group(self._process_group) + del self._process_group + + def _broadcast_weights(self, step: int): + if self._do_broadcast: + self._client.xadd( + REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "weights_ready", "step": step})} + ) + for shard_name, layer_name, tensor in self._model.iter_checkpoint(self._config.export, {}): + if self._do_broadcast: + # TODO: ====== Broadcast metadata in advance ======= + meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] + torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0) + torch.distributed.broadcast(tensor, group=self._process_group, group_src=0) + # Broadcast end of weights broadcast + if self._do_broadcast: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b35733cc7..d09e00d7f 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -151,6 +151,9 @@ def __init__(self, config: TrainerConfig): distributed_config=self._config.model.distributed, ) self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() + self._callbacks = { + name: config.get_callback(self._multi_stage) for name, config in self._config.callbacks.items() + } if not self._is_evaluation_only: steps_per_split = { @@ -286,6 +289,8 @@ def run(self) -> None: assert self._is_setup with self._wandb: self._run_training() + for callback in self._callbacks.values(): + callback.train_end(self._completed_steps) def _run_training(self) -> None: self._prepare_training_state() @@ -358,6 +363,10 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably unnecessary. safe_barrier(self._distributed.world_group, "train begin") + + for callback in self._callbacks.values(): + callback.run_begin(self._completed_steps) + if torch.cuda.is_available(): torch.cuda.synchronize() start_time = time.perf_counter() @@ -389,6 +398,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: skipped_iters += 1 nan_iters += not all(math.isfinite(loss) for loss in reduced_losses.values()) + for callback in self._callbacks.values(): + callback.step_end(self._completed_steps, reduced_losses, update_successful, train_metrics) # Logging. 
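For context on the hook points wired in above, a minimal hypothetical callback registered the same way as the streaming one; the `"log_losses"` registry key and the class itself are made up for this example.

```python
import logging
import typing

from fast_llm.config import config_class
from fast_llm.engine.training.config import TrainerCallback, TrainerCallbackConfig

logger = logging.getLogger(__name__)


@config_class(dynamic_type={TrainerCallbackConfig: "log_losses"})
class LossLoggingCallbackConfig(TrainerCallbackConfig):
    def get_callback(self, model) -> "LossLoggingCallback":
        return LossLoggingCallback(self)


class LossLoggingCallback[ConfigType: LossLoggingCallbackConfig](TrainerCallback[ConfigType]):
    # Called after every optimizer step with the reduced losses for that step.
    def step_end(
        self,
        step: int,
        reduced_losses: dict[str, float | int],
        update_successful: bool,
        train_metrics: dict[str, typing.Any] | None,
    ):
        if update_successful:
            logger.info(f"step {step}: {reduced_losses}")
```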
metrics = {} if is_logging: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index f531a1d46..a6057d67f 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -14,6 +14,7 @@ LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -166,3 +167,18 @@ def loss_class(self) -> "type[LanguageModelZLoss]": from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss return LanguageModelZLoss + + +@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) +class LanguageModelGRPOLossConfig(LanguageModelLossConfig): + + _abstract: typing.ClassVar[bool] = False + + epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") + epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") + + @property + def loss_class(self) -> "type[LanguageModelGRPOLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + + return LanguageModelGRPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py new file mode 100644 index 000000000..44338306e --- /dev/null +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -0,0 +1,87 @@ +import typing + +import torch + +from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base +from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + + +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self._prediction_distance > 0: + raise NotImplementedError() + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return grpo_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], kwargs, split_index), + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + ) + + +@torch.compile +def grpo_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + loss_mask = target >= 0 + + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + 
logits_norm, target, loss_mask, group + ) + probability_ratio = (predicted_logits - sum_exp_logits.log() - old_log_probabilities).exp() + + per_sample_loss = -torch.min( + probability_ratio * advantages, + torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, + ) + per_sample_loss = per_sample_loss * loss_mask + loss = per_sample_loss.mean() + + if grad_output is None: + grad = None + else: + # loss[a>=0] = -a * min(x, 1 + epsilon_high) => grad[a>=0] = -a * (x <= 1 + epsilon_high) + # loss[a<=0] = a * max(x, 1 - epsilon_low) => grad[a<=0] = a * (x >= 1 - epsilon_low) + probability_ratio_grad = ( + grad_output + * ( + torch.clamp_min(advantages, 0) * (probability_ratio <= 1 + epsilon_high) + + torch.clamp_max(advantages, 0) * (probability_ratio >= 1 - epsilon_low) + ) + * loss_mask + ) + + # d(probability_ratio)/d(logits) = - probability_ratio * (predicted_probabilities - target_probabilities) + # (Sign absorbed in probability_ratio_grad) + predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze_(-1) + grad = (probability_ratio_grad * probability_ratio).unsqueeze(-1) * predicted_probabilities.scatter_add( + -1, + target_masked.unsqueeze(-1), + -(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32), + ) + grad = grad.to(logits.dtype) + + return loss, grad diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 314741c3b..4e2522fc1 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -59,6 +59,11 @@ class GPTBatchConfig(BatchConfig): desc="Read dpo data (chosen and rejected spans) from the dataset.", hint=FieldHint.feature, ) + use_grpo_data: bool = Field( + default=False, + desc="Read grpo data (advantages and old log probabilities) from the dataset.", + hint=FieldHint.feature, + ) truncate_documents: bool | None = Field( default=True, desc=( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd2932984..eb694a8a6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -293,6 +293,22 @@ def preprocess_batch( labels_begin, labels_end ).ranges + if batch.advantages is not None: + kwargs[LanguageModelKwargs.advantages] = batch.advantages.crop(labels_begin, labels_end).data + if kwargs[AttentionKwargs.sequence_first]: + kwargs[LanguageModelKwargs.advantages] = ( + kwargs[LanguageModelKwargs.advantages].transpose(0, 1).contiguous() + ) + + if batch.old_log_probabilities is not None: + kwargs[LanguageModelKwargs.old_log_probabilities] = batch.old_log_probabilities.crop( + labels_begin, labels_end + ).data + if kwargs[AttentionKwargs.sequence_first]: + kwargs[LanguageModelKwargs.old_log_probabilities] = ( + kwargs[LanguageModelKwargs.old_log_probabilities].transpose(0, 1).contiguous() + ) + tokens = ( cropped_tokens.tokens.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index ded0f81c8..c07810e34 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -39,5 +39,6 @@ def _get_preprocessing_config( "vocab_size": self._config.model.base_model.embeddings.vocab_size, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "use_preference_spans": self._config.batch.use_preference_spans, + "use_grpo_data": self._config.batch.use_grpo_data, } return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/setup.cfg b/setup.cfg index 867e1da29..955702907 100644 --- a/setup.cfg +++ 
b/setup.cfg @@ -59,6 +59,9 @@ SSM = GENERATION = lm_eval>=0.4.9 +STREAMING = + redis>=7.1.0 + # Required for supporting vision inputs VISION = # Vision Tools @@ -76,6 +79,7 @@ DEV = setuptools>=80.9.0 # Dependency manager needs colorama to show colors. colorama>=0.4.6 + fakeredis>=2.32.1 # Required for building the documentation DOCS = diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py new file mode 100644 index 000000000..a39228f44 --- /dev/null +++ b/tests/data/test_streaming.py @@ -0,0 +1,288 @@ +import contextlib +import logging +import pathlib +import typing + +import fakeredis +import pytest +import redis +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, RedisConfig, SamplingParameters, StreamingDatasetConfig +from fast_llm.data.dataset.streaming import RedisDocument, RedisStreamingDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert +from tests.conftest import WorkerResources +from tests.data.common import get_sampling_data +from tests.utils.redis import make_sampling, redis_batch_producer +from tests.utils.subtest import DistributedTestContext + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def fake_redis(monkeypatch): + """Monkeypatch redis.Redis globally.""" + fake_redis = fakeredis.FakeRedis() + monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) + try: + yield fake_redis + finally: + fake_redis.close() + + +@pytest.mark.parametrize( + ("documents", "preprocessing"), + [ + ((range(3),), {}), + ((range(3), range(3, 6)), {}), + ((range(3), range(5), [9, 4]), {}), + (({"tokens": list(range(5)), "loss_masking_spans": [(0, 1), (2, 3)]},), {"use_loss_masking_spans": True}), + ( + ({"tokens": list(range(8)), "chosen_span": (0, 2), "rejected_span": (3, 5)},), + {"use_preference_spans": True}, + ), + ( + ( + {"tokens": list(range(3)), "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, + {"tokens": list(range(4)), "advantage": 0.7, "old_log_probabilities": [1, 2, 3, 4]}, + ), + {"use_grpo_data": True}, + ), + ], +) +def test_streaming_dataset( + fake_redis: fakeredis.FakeRedis, + documents: tuple[list[int] | dict[str, typing.Any], ...], + preprocessing: dict, + worker_resources: WorkerResources, +): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) + dataset_iterator = RedisStreamingDataset(stream_config, DistributedConfig()).iterate( + get_sampling_data(len(documents), preprocessing=LanguageModelPreprocessingConfig.from_dict(preprocessing)) + ) + documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents] + for document in documents: + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message()) + for document in documents: + sample = next(dataset_iterator) + assert isinstance(sample, LanguageModelSample) + 
Assert.eq(sample.tokens.tokens.tolist(), document["tokens"]) + Assert.eq(sample.tokens.lengths, [len(document["tokens"])]) + + if "loss_masking_spans" in document: + Assert.eq(sample.loss_masking_spans.ranges, document["loss_masking_spans"]) + else: + assert sample.loss_masking_spans is None + + if "chosen_span" in document: + Assert.eq(sample.chosen_spans.ranges, [document["chosen_span"]]) + else: + assert sample.chosen_spans is None + + if "rejected_span" in document: + Assert.eq(sample.rejected_spans.ranges, [document["rejected_span"]]) + else: + assert sample.rejected_spans is None + + assert sample.image_patches is None + + if "advantage" in document: + Assert.rms_close( + sample.advantages.data, torch.full([len(document["tokens"])], document["advantage"]), 1e-8 + ) + else: + assert sample.advantages is None + + if "old_log_probabilities" in document: + Assert.rms_close(sample.old_log_probabilities.data, torch.tensor(document["old_log_probabilities"]), 1e-8) + else: + assert sample.old_log_probabilities is None + + +@pytest.mark.parametrize( + ("messages", "expected_samples", "expected_lengths"), + [ + ((range(5),), (range(5),), ([5],)), # Single message, exact fit. + ((range(3), [3, 4]), (range(5),), ([3, 2],)), # Two messages, exact fit. + ((range(6), range(5)), (range(5),), ([5],)), # Two messages, one dropped. + ( + (range(3), range(5)), + ( + [0, 1, 2, -100, -100], + range(5), + ), + ( + [3, 2], + [5], + ), + ), # Two messages, one padded. + ], +) +def test_streaming_sampled_dataset( + fake_redis: fakeredis.FakeRedis, + messages: tuple[list[int], ...], + expected_samples: tuple[list[int], ...], + expected_lengths: tuple[int, ...], + worker_resources: WorkerResources, +): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) + distributed = Distributed(DistributedConfig(use_cuda=False)) + dataset_iterator = iter( + RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) + ) + for message in messages: + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict({"tokens": list(message)}).to_message()) + for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): + sample = next(dataset_iterator) + assert isinstance(sample, LanguageModelSample) + Assert.eq(sample.tokens.tokens.tolist(), list(expected_sample)) + Assert.eq(sample.tokens.lengths, expected_lengths_) + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +_NUM_BATCHES = 1 + + +def _get_distributed_and_batch_config( + distributed_config_dict: dict[str, typing.Any], world_size: int = 1 +) -> tuple[DistributedConfig, GPTBatchConfig]: + distributed_config = DistributedConfig.from_dict( + distributed_config_dict, + { + "world_size": world_size, + "local_world_size": world_size, + "use_cuda": False, + "backend": DistributedBackend.gloo, + }, + ) + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10) + batch_config.setup(distributed_config=distributed_config) + batch_config.validate() + return distributed_config, batch_config + + +def _run_test_data_streaming( + path: pathlib.Path, distributed_config: DistributedConfig, batch_config: GPTBatchConfig, port: int +): + redis_config = RedisConfig(port=port + 100) + + data = GPTData(GPTDataConfig(datasets={"train": {"type": "streaming", "port": port + 100}}), distributed_config) + 
distributed = Distributed(distributed_config) + with ( + redis_batch_producer(redis_config, batch_config) if distributed_config.rank == 0 else contextlib.nullcontext() + ): + data.setup( + distributed=distributed, + sampling_parameters={ + "train": SamplingParameters( + sequence_length=batch_config.sequence_length, + extra_tokens=0, + num_samples=batch_config.batch_size * _NUM_BATCHES, + truncate_documents=False, + ) + }, + preprocessing=LanguageModelPreprocessingConfig(), + cache_directory=path / "cache", + timeout=5, + ) + + data_iter = data.get_iterator(batch_config, "train", consumed_samples=0, num_workers=0, prefetch_factor=None) + batches = [next(data_iter) for _ in range(_NUM_BATCHES)] + path.mkdir(parents=True, exist_ok=True) + torch.save( + torch.stack([batch.tokens.tokens[:, 0] for batch in batches]), + path / f"rank_{distributed_config.batch_data_rank}_" + f"{distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank}.pt", + ) + # Wait for other processes to finish before shutting down the server. + safe_barrier(distributed.world_group, "streaming test end") + + +def check_data_streaming_results( + path: pathlib.Path, + distributed_config: DistributedConfig, + batch_config: GPTBatchConfig, +): + sample_indexes = set() + for batch_data_rank in range(distributed_config.batch_data_parallel): + batches_tokens = torch.load(path / f"rank_{batch_data_rank}_0.pt") + Assert.eq(batches_tokens.shape, (_NUM_BATCHES, batch_config.micro_batch_size)) + for model_and_sequence_data_rank in range( + 1, distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).size + ): + Assert.all_equal( + torch.load(path / f"rank_{batch_data_rank}_{model_and_sequence_data_rank}.pt"), batches_tokens + ) + sample_indexes.update(batches_tokens.flatten().tolist()) + Assert.eq(len(sample_indexes), _NUM_BATCHES * batch_config.batch_size) + + +def _run_test_data_streaming_distributed( + test_context: DistributedTestContext, base_path: pathlib.Path, port: int +) -> None: + # Import all dynamic classes. TODO: needed? 
+ import fast_llm.cli # noqa + + print(_DISTRIBUTED_TESTING_CONFIGS) + for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS: + with test_context.subtest(base_path, name, num_gpus) as subtest: + print(name, subtest.do_run) + if subtest.do_run: + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) + _run_test_data_streaming(base_path / name, distributed_config, batch_config, port) + + +def test_data_streaming(result_path, worker_resources): + distributed_config, batch_config = _get_distributed_and_batch_config({}) + path = result_path / "data_streaming/single_gpu" + _run_test_data_streaming(path, distributed_config, batch_config, worker_resources.torchrun_port) + check_data_streaming_results(path, distributed_config, batch_config) + + +_DISTRIBUTED_TESTING_CONFIGS = [ + ("dp2", 2, {}), + ("sdp2", 2, {"sequence_data_parallel": 2}), + ("tp2", 2, {"tensor_parallel": 2}), + ("pp2", 2, {"pipeline_parallel": 2}), + ("dp2_sdp2", 4, {"sequence_data_parallel": 2}), + ("dp2_tp2", 4, {"tensor_parallel": 2}), + ("dp2_pp2", 4, {"pipeline_parallel": 2}), + ("sdp2_tp2", 4, {"sequence_data_parallel": 2, "tensor_parallel": 2}), + ("sdp2_pp2", 4, {"sequence_data_parallel": 2, "pipeline_parallel": 2}), + ("tp2_pp2", 4, {"tensor_parallel": 2, "pipeline_parallel": 2}), +] + + +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_data_streaming"]) +def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): + run_parallel_script( + _run_test_data_streaming_distributed, + (result_path / "data_streaming", worker_resources.torchrun_port), + world_size=4, + backend=DistributedBackend.gloo, + use_cuda=False, # Disable device count check. + ) + + +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_data_streaming"]) +@pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) +def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): + report_subtest(path := result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) + check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ee3e0e2e1..e3d5094ea 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,8 +9,10 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.loss.config import LanguageModelLossKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert +from tests.layers.test_lm_losses import reference_grpo_loss from tests.utils.utils import get_base_model, get_stage SEQUENCE_LENGTH = 200 @@ -25,6 +27,7 @@ class LMHeadTestConfig: label_loss: bool | float = False distillation_loss: bool | float = False z_loss: bool | float = False + grpo_loss: bool | float = False logits_scale_factor: float = 1.0 compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False @@ -38,7 +41,10 @@ class LMHeadTestConfig: def actual_label_loss(self): return ( True - if self.label_loss is False and self.distillation_loss is False and self.z_loss is False + if self.label_loss is False + and self.distillation_loss is 
False + and self.z_loss is False + and self.grpo_loss is False else self.label_loss ) @@ -61,6 +67,10 @@ def get_config(self) -> GPTModelConfig: losses["z_loss"] = {"type": "z_loss"} if isinstance(self.z_loss, float): losses["z_loss"]["weight"] = self.z_loss + if self.grpo_loss is not False: + losses["grpo_loss"] = {"type": "grpo"} + if isinstance(self.grpo_loss, float): + losses["grpo_loss"]["weight"] = self.grpo_loss if losses: head_config["losses"] = losses @@ -108,7 +118,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: } if self.loss_masking: kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device) - if self.actual_label_loss is not False: + if self.actual_label_loss is not False or self.grpo_loss is not False: labels = torch.randint( 0, VOCAB_SIZE, @@ -127,6 +137,19 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: dtype=input_.dtype, device=device, ) + + if self.grpo_loss is not False: + kwargs[LanguageModelLossKwargs.advantages] = torch.randn( + input_.shape[:-1], + dtype=torch.float32, + device=device, + ) + kwargs[LanguageModelLossKwargs.old_log_probabilities] = torch.randn( + input_.shape[:-1], + dtype=torch.float32, + device=device, + ) + return input_, kwargs def get_reference_outputs( @@ -152,7 +175,7 @@ def get_reference_outputs( total_loss = 0 losses = {} - if self.actual_label_loss is not False: + if self.actual_label_loss is not False or self.grpo_loss is not False: if self.sequence_first: labels = kwargs[LanguageModelKwargs.labels][ head._prediction_distance : head._prediction_distance + logits.size(0) @@ -161,6 +184,7 @@ def get_reference_outputs( labels = kwargs[LanguageModelKwargs.labels][ :, head._prediction_distance : head._prediction_distance + logits.size(1) ] + if self.actual_label_loss is not False: label_loss = torch.nn.functional.cross_entropy( logits.flatten(0, -2), labels.flatten(), reduction="none" ).mean() @@ -187,6 +211,16 @@ def get_reference_outputs( losses["z_loss"] = z_loss.detach() total_loss = total_loss + float(self.z_loss) * z_loss + if self.grpo_loss is not False: + grpo_loss = reference_grpo_loss( + logits, + labels, + kwargs[LanguageModelLossKwargs.advantages], + kwargs[LanguageModelLossKwargs.old_log_probabilities], + ) + losses["grpo_loss"] = grpo_loss.detach() + total_loss = total_loss + float(self.grpo_loss) * grpo_loss + total_loss.backward() if len(losses) > 1: @@ -227,6 +261,7 @@ def _add_configs(base_name: str, **kwargs): _add_configs("label_loss", label_loss=True) _add_configs("distillation_loss", distillation_loss=True) _add_configs("z_loss", z_loss=True) +_add_configs("grpo_loss", grpo_loss=True) _add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True) _add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5) _add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 639a3ba7c..c6567db47 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -13,6 +13,7 @@ from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward +from fast_llm.layers.language_model.loss.grpo import grpo_loss_forward_backward from fast_llm.layers.language_model.loss.loss import 
loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward from fast_llm.utils import Assert @@ -46,6 +47,19 @@ def _get_lm_loss_inputs( return logits, target, loss_mask +def _get_grpo_loss_inputs(num_columns: int, loss_masking: bool, batch_shape: tuple[int], dtype): + logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.labels, batch_shape, dtype) + advantages = torch.randn_like(target, dtype=torch.float32) + # We want some correlation between the old and new log probabilities for the test to be meaningful. + old_log_probabilities = ( + torch.nn.functional.log_softmax(logits, dim=-1) + .gather(dim=-1, index=(target * (target >= 0) if loss_masking else target).unsqueeze(-1)) + .squeeze(-1) + + torch.randn_like(target, dtype=torch.float32) / 2 + ) + return logits, target, advantages, old_log_probabilities + + def _compare_losses_and_grads( loss: torch.Tensor, ref_loss: torch.Tensor, @@ -107,6 +121,33 @@ def reference_dpo_loss( return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() +def reference_grpo_loss( + logits: torch.Tensor, + labels: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + logits_ = logits.float() + + # Log probabilities. + loss_mask = labels >= 0 + labels = labels * loss_mask + target_log_probabilities = ( + torch.nn.functional.log_softmax(logits_ * logits_scale_factor, dim=-1) + .gather(dim=-1, index=labels.unsqueeze(-1)) + .squeeze(-1) + ) + probability_ratio = torch.exp(target_log_probabilities - old_log_probabilities) + loss = -torch.min( + probability_ratio * advantages, + torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, + ) + return (loss * loss_mask).mean() + + _BATCH_SHAPES = ((64,), (16, 8)) _LOSS_PARAMETERS = ( (500, 1.0, 1.0, False, DataType.float32), # Simple @@ -187,6 +228,31 @@ def _test_entropy_loss( _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) +def _test_grpo_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): + logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( + num_columns, loss_masking, batch_shape, dtype + ) + out_ref, grad_ref = loss_forward_backward( + grad_output, + reference_grpo_loss, + logits, + target, + advantages, + old_log_probabilities, + logits_scale_factor=logits_scale_factor, + ) + out_fused, grad_fused = grpo_loss_forward_backward( + split_op(logits, group, -1), + target, + advantages, + old_log_probabilities, + grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + ) + _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + + def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) out_ref, grad_ref = loss_forward_backward( @@ -237,6 +303,15 @@ def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), 
_LOSS_PARAMETERS
+)
+def test_grpo_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype):
+    _test_grpo_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype)
+
+
 @pytest.mark.skip(reason="DPO loss is broken")
 def test_dpo_loss():
     logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE))
@@ -287,6 +362,19 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa
                     dtype,
                     test_context.group,
                 )
+        # GRPO
+        with test_context.subtest(base_path, f"grpo-{suffix}", 2) as subtest:
+            if subtest.do_run:
+                torch.manual_seed((seed + hash(subtest.name)) % 2**32)
+                _test_grpo_loss(
+                    batch_shape,
+                    num_columns,
+                    grad_output,
+                    logits_scale_factor,
+                    loss_masking,
+                    dtype,
+                    test_context.group,
+                )
 
 
 @pytest.mark.slow
@@ -325,6 +413,7 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path):
             if target_format != TargetFormat.labels or entropy_loss_type != EntropyLossType.reverse_kl
         ),
         "z_loss",
+        "grpo",
     ),
 )
 def test_lm_loss_distributed(
diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py
new file mode 100644
index 000000000..f132b465d
--- /dev/null
+++ b/tests/models/test_streaming.py
@@ -0,0 +1,213 @@
+import contextlib
+import dataclasses
+import functools
+import json
+import logging
+import pathlib
+
+import pytest
+import safetensors
+import torch
+
+from fast_llm.engine.training.config import StreamingTrainerCallbackConfig
+from fast_llm.engine.training.streaming import REDIS_TRAINING_FIELD, REDIS_TRAINING_STREAM
+from fast_llm.utils import Assert
+from tests.conftest import WorkerResources
+from tests.models.test_checkpoint import compare_safetensor_files
+from tests.utils.distributed_configs import DistributedTestingConfig
+from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup, update_and_add_testing_config
+from tests.utils.redis import redis_batch_producer
+from tests.utils.run_test_script import do_run_test_script_for_all_models
+from tests.utils.subtest import DistributedTestContext
+from tests.utils.utils import requires_cuda
+
+
+@dataclasses.dataclass(kw_only=True)
+class StreamingDistributedTestingConfig(DistributedTestingConfig):
+    consumer_count: int = 1
+
+    @functools.cached_property
+    def total_gpus(self) -> int:
+        return self.num_gpus + self.consumer_count
+
+
+_DISTRIBUTED_STREAMING_CONFIGS = [
+    StreamingDistributedTestingConfig(name="streaming_simple", config_args=[], num_gpus=1, consumer_count=1),
+    StreamingDistributedTestingConfig(name="streaming_dp2", config_args=[], num_gpus=2, consumer_count=1),
+    StreamingDistributedTestingConfig(
+        name="streaming_sdp2_c2",
+        config_args=["model.distributed.sequence_data_parallel=2"],
+        num_gpus=2,
+        consumer_count=2,
+    ),
+    StreamingDistributedTestingConfig(
+        name="streaming_tp2", config_args=["model.distributed.tensor_parallel=2"], num_gpus=2, consumer_count=2
+    ),
+    StreamingDistributedTestingConfig(
+        name="streaming_stp2_c2",
+        config_args=[
+            "model.distributed.tensor_parallel=2",
+            "model.distributed.sequence_tensor_parallel=true",
+            "callbacks.streaming.broadcast.external_world_size=2",
+        ],
+        num_gpus=2,
+        consumer_count=2,
+    ),
+]
+
+
+def _run_event_consumer(
+    streaming_config: StreamingTrainerCallbackConfig, consumer_index: int, base_path: pathlib.Path
+) -> None:
+    client = streaming_config.get_client()
+    init_method = f"tcp://{streaming_config.broadcast.host}:{streaming_config.broadcast.port}"
+    logging.info(f"Waiting for weights broadcast rendezvous at {init_method} ...")
+    path = 
base_path / "streaming" + path.mkdir(parents=True, exist_ok=True) + field = REDIS_TRAINING_FIELD.encode() + # TODO: Create a custom process group instead. + try: + process_group = torch.distributed.init_process_group( + backend="nccl", + init_method=init_method, + world_size=streaming_config.broadcast.external_world_size + 1, + rank=consumer_index + 1, + ) + last_id = "0-0" + while True: + result = client.xread( + streams={REDIS_TRAINING_STREAM: last_id}, + count=1, + block=10000, + ) + if not result: + raise TimeoutError("No message received after 10000 ms...") + + ((stream, events),) = result + Assert.eq(stream.decode(), REDIS_TRAINING_STREAM) + Assert.eq(len(events), 1) + for last_id, message in events: + Assert.eq(message.keys(), {field}) + message = json.loads(message[field].decode()) + logging.info(f"Received: {message}") + if message["type"] == "training_finished": + return + elif message["type"] == "weights_ready": + weights = {} + while True: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=process_group, group_src=0) + if meta[0] is None: + print(f"Weight broadcast finished") + break + logging.info(f"receiving {meta[0]}") + shard_name, layer_name, tensor_size, tensor_type = meta[0] + tensor = torch.zeros(tuple(tensor_size), dtype=tensor_type, device="cuda") + torch.distributed.broadcast(tensor, group=process_group, group_src=0) + if shard_name == "weights": + weights[layer_name] = tensor + safetensors.torch.save_file( + weights, path / f"rank_{consumer_index}_step_{message["step"]}.safetensors" + ) + + finally: + torch.distributed.destroy_process_group() + + +def _run_model_streaming_configs( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig, port: int +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for config in _DISTRIBUTED_STREAMING_CONFIGS: + model_testing_config = update_and_add_testing_config( + model_testing_config, + None, + updates={ + ("data", "datasets"): {"training": {"port": port}}, + ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, + "callbacks": { + "streaming": { + "type": "streaming", + "port": port, + "broadcast": { + "port": port + 1000, + "external_world_size": config.consumer_count, + }, + "export": {"format": model_testing_config.checkpoint_format.name}, + } + }, + # Disable tensor logging. 
+ ("run", "tensor_logs"): {}, + ("model", "multi_stage"): {}, + }, + groups={}, + ) + with test_context.subtest(base_path, config.name, config.total_gpus) as subtest: + if subtest.do_run: + if test_context.rank < config.num_gpus: + do_run_test_script_for_all_models(config, model_testing_config, base_path) + elif test_context.rank < config.total_gpus: + training_config = model_testing_config.trainer_config_class.from_dict( + model_testing_config.config_dict + ) + with ( + redis_batch_producer(training_config.callbacks["streaming"], training_config.batch) + if test_context.rank == config.num_gpus + else contextlib.nullcontext() + ): + _run_event_consumer( + training_config.callbacks["streaming"], + test_context.rank - config.num_gpus, + base_path / config.name, + ) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +def test_model_streaming(run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources): + # `test_run_model_distributed_streaming` and `test_model_distributed_streaming need a common dependency + # so they are placed in the same testing group and run in the same distributed process. + pass + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +def test_run_model_distributed_streaming( + run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources +): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_model_streaming_configs, + (run_test_script_base_path, model_testing_config, worker_resources.torchrun_port), + world_size=torch.cuda.device_count(), + backend=model_testing_config.distributed_backend, + ) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +@pytest.mark.parametrize("config", _DISTRIBUTED_STREAMING_CONFIGS) +def test_model_distributed_streaming( + config: StreamingDistributedTestingConfig, + run_distributed_script, + model_testing_config, + run_test_script_base_path, + worker_resources: WorkerResources, + report_subtest, +): + report_subtest(path := run_test_script_base_path / config.name, config.total_gpus) + compare_safetensor_files( + path / "export" / model_testing_config.checkpoint_format.name / f"1/model_0.safetensors", + *( + path / "streaming" / f"rank_{consumer_index}_step_1.safetensors" + for consumer_index in range(config.consumer_count) + ), + ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 4ad122947..6456f1589 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -129,6 +129,7 @@ def _get_hf_test_dataset( min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, + has_grpo_data: bool = False, min_images: int = 0, max_images: int = 0, min_image_size: int = 4, @@ -153,6 +154,9 @@ def _get_hf_test_dataset( document_sizes, min_images, max_images, min_image_size, max_image_size, random_state ) + if has_grpo_data: + dataset_dict["advantages"] = random_state.randn(num_documents).tolist() + return datasets.Dataset.from_dict(dataset_dict) @@ -168,6 +172,7 @@ def _get_test_dataset( min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, + 
has_grpo_data: bool = False,
     splits: dict[str, float] | None = None,
     min_images: int = 0,
     max_images: int = 0,
@@ -192,6 +197,7 @@ def _get_test_dataset(
         min_loss_masking_spans=min_loss_masking_spans,
         max_loss_masking_spans=max_loss_masking_spans,
         has_preference_spans=has_preference_spans,
+        has_grpo_data=has_grpo_data,
         min_images=min_images,
         max_images=max_images,
         min_image_size=min_image_size,
@@ -207,6 +213,8 @@ def _get_test_dataset(
     if max_images > 0:
         source_schema["images"] = "images"
         source_schema["image_positions"] = "image_positions"
+    if has_grpo_data:
+        source_schema["advantages"] = "advantages"
 
     download_santacoder_tokenizer()
     preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict(
@@ -239,6 +247,7 @@ def _get_test_dataset(
         vocab_size=max_vocab_size,
         use_loss_masking_spans=max_loss_masking_spans > 0,
         use_preference_spans=has_preference_spans,
+        use_grpo_data=has_grpo_data,
     )
     return path, config, hf_path, preprocessing
 
@@ -324,6 +333,7 @@ def get_model_test_dataset(config_only: bool = False):
         seed=1234,
         num_documents=200,
         max_loss_masking_spans=5,
+        has_grpo_data=True,
         max_vocab_size=MODEL_TEST_VOCAB_SIZE,
         splits={"training": 180, "validation": 19, "test": 1},
         config_only=config_only,
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 5e7526377..8c5be2979 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -53,6 +53,7 @@ class ModelTestingGroup(enum.StrEnum):
     generate = "generate"
     megatron = "megatron"
     distributed = "distributed"
+    streaming = "streaming"
 
 
 class ModelTestingGroupAction(enum.StrEnum):
@@ -391,6 +392,7 @@ def update_and_add_testing_config(
         ModelTestingGroup.generate: ModelTestingGroupAction.broken,
         ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
         ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
+        ModelTestingGroup.streaming: ModelTestingGroupAction.normal,
     },
 )
 
@@ -733,6 +735,25 @@ def update_and_add_testing_config(
     auto_model_class=transformers.AutoModelForImageTextToText,
 )
 
+update_and_add_testing_config(
+    # Tests the GRPO loss and GRPO batch data.
+    "llama",
+    "llama_grpo",
+    updates={
+        ("model", "base_model", "head", "losses"): {"grpo": {"type": "grpo"}},
+        ("batch", "use_grpo_data"): True,
+    },
+    groups={
+        ModelTestingGroup.basic: ModelTestingGroupAction.normal,
+        ModelTestingGroup.checkpoint: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
+        ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
+        ModelTestingGroup.streaming: ModelTestingGroupAction.normal,
+    },
+)
+
 update_and_add_testing_config(
     # Tests hybrid with attention + gated delta net mixer.
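Before the Redis test utilities below, here is a minimal, hedged sketch of how a producer feeds GRPO-style documents onto the data stream, which is the pattern the streaming tests rely on. `REDIS_DATA_STREAM`, `RedisDocument`, and the document fields (`tokens`, `advantage`, `old_log_probabilities`) come from this diff; using `fakeredis.FakeRedis` as an in-memory stand-in for a real Redis server is an assumption made only for illustration.

```python
# Hypothetical illustration: push a few GRPO documents onto the data stream.
# Assumes the RedisDocument / REDIS_DATA_STREAM interface introduced in this PR.
import fakeredis

from fast_llm.data.dataset.config import REDIS_DATA_STREAM
from fast_llm.data.dataset.streaming import RedisDocument

client = fakeredis.FakeRedis()  # in-memory stand-in for a real redis.Redis(host, port)

for step in range(3):
    document = {
        "tokens": list(range(10)),
        # Per-sequence advantage and per-token old log-probabilities, read when `batch.use_grpo_data` is set.
        "advantage": 0.5,
        "old_log_probabilities": [-1.0] * 10,
    }
    # Each document becomes one message on the stream; consumers read it through RedisStreamingDataset.
    client.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message())
```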
diff --git a/tests/utils/redis.py b/tests/utils/redis.py
new file mode 100644
index 000000000..198c6df78
--- /dev/null
+++ b/tests/utils/redis.py
@@ -0,0 +1,136 @@
+import contextlib
+import itertools
+import pathlib
+import socket
+import threading
+import time
+
+import fakeredis
+
+from fast_llm.data.dataset.config import (
+    REDIS_DATA_STREAM,
+    REDIS_GROUP_NAME,
+    RedisConfig,
+    SamplingConfig,
+    SamplingData,
+    SamplingParameters,
+    StreamingDatasetConfig,
+)
+from fast_llm.data.dataset.streaming import RedisDocument
+from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
+from fast_llm.models.gpt.config import GPTBatchConfig
+
+
+def find_free_port():
+    """Find a free TCP port and return it."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def wait_until_stream_empty(
+    redis_client,
+    stream_key,
+    consumer_group,
+    stop_event,
+):
+    """
+    Wait until lag == 0, meaning all messages have been delivered AND acknowledged.
+    The absence of the consumer group means the test has not started yet, so we keep waiting.
+    """
+    consumer_group = consumer_group.encode()
+    while not stop_event.is_set():
+        groups = redis_client.xinfo_groups(stream_key)
+
+        g = next((g for g in groups if g["name"] == consumer_group), None)
+        if g is not None:
+            lag = g.get("lag", 0)
+            if lag == 0:
+                return
+
+        time.sleep(0.05)
+
+
+def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig):
+    while not stop_event.is_set():
+        res = redis_client.hget(f"{REDIS_DATA_STREAM}:consumer_count", "0")
+        if res is None:
+            time.sleep(0.05)
+            continue
+        return int(res)
+
+
+@contextlib.contextmanager
+def redis_batch_producer(config: RedisConfig, batch_config: GPTBatchConfig):
+    with fake_redis_server(config):
+        stop_event = threading.Event()
+        client = config.get_client()
+
+        def producer_loop():
+            for sample_index in itertools.count():
+                if stop_event.is_set():
+                    break
+                client.xadd(
+                    REDIS_DATA_STREAM,
+                    RedisDocument.from_dict({"tokens": [sample_index] * batch_config.sequence_length}).to_message(),
+                )
+                if sample_index % 5 == 0:
+                    wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event)
+
+        thread = threading.Thread(target=producer_loop, daemon=True)
+        thread.start()
+
+        try:
+            yield
+        finally:
+            stop_event.set()
+            thread.join(timeout=1)
+            client.close()
+
+
+def make_sampling(sequence_length, num_samples, distributed):
+    return SamplingData(
+        parameters=SamplingParameters(
+            sequence_length=sequence_length,
+            extra_tokens=0,
+            num_samples=num_samples,
+            truncate_documents=False,
+        ),
+        config=SamplingConfig(),
+        distributed=distributed,
+        dataset_name="test",
+        cache_directory=pathlib.Path("/tmp"),
+        preprocessing=LanguageModelPreprocessingConfig(),
+    )
+
+
+@contextlib.contextmanager
+def fake_redis_server(config: RedisConfig):
+    # We search for a free port since the port from a previous test may still be in use even after server shutdown.
+
+    # ----- Monkey-patch handler to suppress broken pipes -----
+    orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle
+
+    def safe_handle(self):
+        try:
+            orig_handle(self)
+        except (ConnectionResetError, BrokenPipeError):
+            # Client disconnected abruptly (e.g., when a PyTorch DataLoader iterator is deleted).
+            # These errors occur only with fake Redis and can be safely ignored.
+ pass + except Exception as e: + print(f"Unexpected exception in fake Redis handler: {e}") + + fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + try: + yield + finally: + # ----- Teardown ----- + server.shutdown() + server.server_close() + thread.join()
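As a closing reference, the clipped-surrogate objective that `grpo_loss_forward_backward` implements analytically (and that `reference_grpo_loss` re-derives in the tests) can be sanity-checked against autograd with a standalone sketch. This is an illustration only and does not import Fast-LLM; the shapes, the `-100` label-masking convention, and the 0.2 epsilon defaults mirror the code above, while the helper name `clipped_surrogate` is hypothetical.

```python
# Standalone sanity check of the GRPO clipped-surrogate loss used in this PR (illustrative only).
import torch


def clipped_surrogate(logits, labels, advantages, old_log_probabilities, eps_low=0.2, eps_high=0.2):
    loss_mask = labels >= 0
    safe_labels = labels * loss_mask  # masked labels become index 0, then get zeroed out by the mask below
    log_probs = torch.log_softmax(logits.float(), dim=-1).gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    ratio = (log_probs - old_log_probabilities).exp()
    per_token = -torch.min(ratio * advantages, torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * advantages)
    return (per_token * loss_mask).mean()


torch.manual_seed(0)
logits = torch.randn(8, 32, requires_grad=True)
labels = torch.randint(0, 32, (8,))
labels[0] = -100  # masked position, ignored by the loss
advantages = torch.randn(8)
old_log_probabilities = torch.randn(8).abs().neg()  # plausible (negative) old log-probabilities

loss = clipped_surrogate(logits, labels, advantages, old_log_probabilities)
loss.backward()
# logits.grad now holds the gradient the fused kernel computes analytically
# (up to the grad_output scaling applied inside grpo_loss_forward_backward).
print(loss.item(), logits.grad.norm().item())
```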