Description
Describe the bug
Data duplication across ranks when processing an `IterableDataset` with `split_dataset_by_node` followed by `interleave_datasets`
Steps to reproduce the bug
I have provided a minimal script:
```python
import os

from datasets import interleave_datasets, load_dataset
from datasets.distributed import split_dataset_by_node

path = "/mnt/wwx/datasets/fineweb/data/CC-MAIN-2013-20/"
files = [os.path.join(path, fn) for fn in os.listdir(path)]
dataset = load_dataset("parquet", split="train", data_files=files, streaming=True)
print(f"{dataset.n_shards=}")

# split by node first, then interleave
dataset_rank0 = split_dataset_by_node(dataset, 0, 4)
dataset_rank1 = split_dataset_by_node(dataset, 1, 4)
dataset_rank0_interleaved = interleave_datasets([dataset_rank0], seed=42, probabilities=[1.0])
dataset_rank1_interleaved = interleave_datasets([dataset_rank1], seed=42, probabilities=[1.0])

print("print the first sample id from all datasets")
print("dataset", next(iter(dataset))['id'])
print("dataset_rank0", next(iter(dataset_rank0))['id'])
print("dataset_rank1", next(iter(dataset_rank1))['id'])
print("dataset_rank0_interleaved", next(iter(dataset_rank0_interleaved))['id'])
print("dataset_rank1_interleaved", next(iter(dataset_rank1_interleaved))['id'])

# for comparison: shard first, then interleave
dataset_rank0_shard = dataset.shard(4, 0)
dataset_rank1_shard = dataset.shard(4, 1)
dataset_rank0_shard_interleaved = interleave_datasets([dataset_rank0_shard], seed=42, probabilities=[1.0])
dataset_rank1_shard_interleaved = interleave_datasets([dataset_rank1_shard], seed=42, probabilities=[1.0])
print("dataset_rank0_shard", next(iter(dataset_rank0_shard))['id'])
print("dataset_rank1_shard", next(iter(dataset_rank1_shard))['id'])
print("dataset_rank0_shard_interleaved", next(iter(dataset_rank0_shard_interleaved))['id'])
print("dataset_rank1_shard_interleaved", next(iter(dataset_rank1_shard_interleaved))['id'])
```

I just used a subfolder of C4 with 14 parquet files to do a quick run and got:
```
dataset.n_shards=14
print the first sample id from all datasets
dataset <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank0 <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank1 <urn:uuid:6b7da64f-c26e-4086-aef5-4b6f01106223>
dataset_rank0_interleaved <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank1_interleaved <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank0_shard <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank1_shard <urn:uuid:67cf7216-dd05-4f55-a28a-1a1c96989c51>
dataset_rank0_shard_interleaved <urn:uuid:c84a7f00-f3e8-4b67-baa4-df5adaf23bae>
dataset_rank1_shard_interleaved <urn:uuid:67cf7216-dd05-4f55-a28a-1a1c96989c51>
```
Expected behavior
The first sample of `dataset_rank0_interleaved` and `dataset_rank1_interleaved` should be different, as it is for the other rank0/rank1 pairs.
I have dug into the code to see how the split -> interleave process works:

`split_dataset_by_node` on an iterable dataset does not change the dataset's `._ex_iterable` attribute; it only sets the distributed config on the dataset, which is then used in the actual `__iter__` call to handle shard splitting or sample skipping.

However, `interleave_datasets` on iterable datasets copies out the `._ex_iterable` of every provided dataset and builds a new `_ex_iterable` from them. The distributed config is never copied over, which causes the data duplication across DP ranks.
So I would first ask: is this an unsupported order of operations for these functions? That is, should one:
- always apply `split_dataset_by_node` at the end rather than in the middle, or
- use `dataset.shard(dp_size, dp_rank)` rather than `split_dataset_by_node` in cases like mine?
If this order of use is permitted, I think it is a bug, and I can open a PR to fix it.
(I hit this bug in real training; the related issue is ByteDance-Seed/VeOmni#200, if it helps.)
Environment info
- `datasets` 4.4.1
- Ubuntu 20.04
- Python 3.11.4