24 changes: 12 additions & 12 deletions src/scratch/datasets/causal_langauge_modeling_dataset.py
@@ -90,18 +90,18 @@ def load_hf_dataset(
     )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     def tokenize_function(examples):
         return tokenizer(examples["text"], padding="max_length", truncation=True)
 
-    data = data.map(tokenize_function, batched=True).with_format("torch")
+    data = data.map(tokenize_function, batched=True)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")
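Note on the recurring pattern: each file drops `.with_format("torch")` from the intermediate shuffle/filter/map steps and applies it once to the returned dataset. A minimal sketch of the resulting pipeline shape, assuming a Hugging Face streaming dataset (the wikitext-2 config below is illustrative, not taken from this diff):

import datasets

# Every step below stays lazy and in the datasets-native format;
# the torch conversion happens exactly once, at the pipeline boundary.
data = datasets.load_dataset(
    "wikitext", "wikitext-2-raw-v1", split="train", streaming=True
)
data = data.shuffle()
data = data.filter(lambda ex: len(ex["text"]) > 0)
data = data.map(lambda ex: {"n_chars": len(ex["text"])})
data = data.with_format("torch")  # tensors only from here on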

@@ -162,9 +162,9 @@ def transform(batch: CausalLanguageModelingBatch):
batch["attention_mask"],
batch["labels"],
)
input_ids = torch.tensor(input_ids, dtype=torch.int64)
attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
labels = torch.tensor(labels, dtype=torch.int64)
input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
attention_mask = torch.as_tensor(attention_mask, dtype=torch.int64)
labels = torch.as_tensor(labels, dtype=torch.int64)
return CausalLanguageModelingBatch(
input_ids=input_ids, attention_mask=attention_mask, labels=labels
)
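The `torch.tensor` -> `torch.as_tensor` swap repeated throughout this diff avoids one copy per batch: `torch.tensor` always allocates new storage, while `torch.as_tensor` reuses the incoming buffer when dtype and device already match. A small sketch of the difference, assuming numpy int64 input:

import numpy as np
import torch

arr = np.arange(4, dtype=np.int64)
copied = torch.tensor(arr)                        # always copies
shared = torch.as_tensor(arr, dtype=torch.int64)  # reuses arr's memory

arr[0] = 99
print(copied[0].item())  # 0  -> independent storage
print(shared[0].item())  # 99 -> shares memory with arr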
@@ -196,12 +196,12 @@ def wikitext2_dataset(
     tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
 
     def prepare(sample):
-        input_ids = sample["input_ids"]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        labels = input_ids.clone()
+        input_ids = np.array(sample["input_ids"], dtype=np.int64)
+        labels = input_ids.copy()
         # Make a lower triangular attention mask
-        attention_mask = np.tril(np.ones((len(input_ids), len(input_ids))))
-        attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
+        attention_mask = np.tril(
+            np.ones((len(input_ids), len(input_ids)), dtype=np.int64)
+        )
         sample["input_ids"], sample["attention_mask"], sample["labels"] = (
             input_ids,
             attention_mask,
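Building the mask as int64 inside `np.ones` removes the separate torch cast that used to follow. The lower triangle from `np.tril` lets position i attend only to positions j <= i; for a length-4 sequence (a worked example, not part of the diff):

import numpy as np

seq_len = 4
mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]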
52 changes: 27 additions & 25 deletions src/scratch/datasets/image_classification_dataset.py
@@ -94,17 +94,20 @@ def load_hf_dataset(
         The IterableDataset object
     """
     data = load_dataset(
-        dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
-    ).with_format("torch")
+        dataset_name,
+        split=dataset_split,
+        trust_remote_code=True,
+        streaming=True,
+    )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")

@@ -172,15 +175,13 @@ def mnist_dataset(batch_size=32, shuffle=True):

     def prepare(sample):
         images, labels = sample["image"], sample["label"]
-        # Ensure the images are float tensors
-        images = images.to(torch.float32)
-        # Normalize the images
-        images = images / 255.0
-        # Convert labels to one-hot encoding
-        labels = labels.to(torch.int64)  # Ensure labels are int32 tensors
-        labels = F.one_hot(labels, num_classes=10).to(torch.int32)
+        images = transforms.ToTensor()(images).to(torch.float32)
+        labels = F.one_hot(
+            torch.as_tensor(labels, dtype=torch.int64),
+            num_classes=10,
+        ).to(torch.int32)
 
-        sample["image"], sample["label"] = images, labels
+        sample["image"], sample["label"] = images.numpy(), labels.numpy()
         return sample
 
     train_data, test_data = (
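The rewritten `prepare` drops the explicit `/ 255.0` because `transforms.ToTensor` already converts PIL/uint8 images to CHW float32 scaled into [0, 1]; the trailing `.numpy()` calls appear intended to keep mapped samples as plain arrays until the final `with_format("torch")`. A quick check of the ToTensor behavior, assuming a 28x28 grayscale MNIST-style image:

import numpy as np
from PIL import Image
from torchvision import transforms

img = Image.fromarray(np.full((28, 28), 255, dtype=np.uint8))  # all-white image
t = transforms.ToTensor()(img)
print(t.shape, t.dtype, t.max().item())
# torch.Size([1, 28, 28]) torch.float32 1.0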
@@ -219,20 +220,21 @@ def tiny_imagenet_dataset(batch_size=32, shuffle=True):

     def prepare(sample):
         images, labels = sample["image"], sample["label"]
-        # Ensure the images are float tensors
-        images = images.clone().detach().to(torch.float32)
-        # Normalize the images
-        images = images / 255.0
-        # Convert labels to one-hot encoding
-        labels = labels.clone().detach().to(torch.int64)  # Ensure labels are int32
-        labels = F.one_hot(labels, num_classes=200).to(torch.int32)
+        images = transforms.ToTensor()(images).to(torch.float32)
+        labels = F.one_hot(
+            torch.as_tensor(labels, dtype=torch.int64),
+            num_classes=200,
+        ).to(torch.int32)
 
-        sample["image"], sample["label"] = images, labels
+        sample["image"], sample["label"] = images.numpy(), labels.numpy()
         return sample
 
     def validate(sample):
-        transform = transforms.ToTensor()
-        img = transform(sample["image"])
+        img = (
+            sample["image"]
+            if isinstance(sample["image"], torch.Tensor)
+            else transforms.ToTensor()(sample["image"])
+        )
         return (
             img.shape == (3, 64, 64)
             and torch.isnan(img).sum() == 0
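The `isinstance` guard in `validate` matters because `transforms.ToTensor` accepts only PIL images or numpy arrays, so an unconditional conversion breaks once upstream steps may already yield tensors. A minimal reproduction of the failure the guard avoids (behavior as of current torchvision; the exact message may vary by version):

import torch
from torchvision import transforms

try:
    transforms.ToTensor()(torch.zeros(3, 64, 64))
except TypeError as err:
    print(err)  # e.g. "pic should be PIL Image or ndarray"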
8 changes: 4 additions & 4 deletions src/scratch/datasets/question_answering_dataset.py
@@ -88,10 +88,10 @@ def transform(batch):
batch["start_positions"],
batch["end_positions"],
)
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
start_positions = torch.tensor(start_positions, dtype=torch.long)
end_positions = torch.tensor(end_positions, dtype=torch.long)
input_ids = torch.as_tensor(input_ids, dtype=torch.long)
attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
start_positions = torch.as_tensor(start_positions, dtype=torch.long)
end_positions = torch.as_tensor(end_positions, dtype=torch.long)
return QuestionAnsweringBatch(
input_ids=input_ids,
attention_mask=attention_mask,
38 changes: 16 additions & 22 deletions src/scratch/datasets/sequence_classification_dataset.py
@@ -81,22 +81,25 @@ def load_hf_dataset(
         The IterableDataset object
     """
     data = load_dataset(
-        dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
+        dataset_name,
+        split=dataset_split,
+        trust_remote_code=True,
+        streaming=True,
     )
 
     if shuffle:
-        data = data.shuffle().with_format("torch")
+        data = data.shuffle()
 
     if validate:
-        data = data.filter(validate).with_format("torch")
+        data = data.filter(validate)
 
     def tokenize_function(examples):
         return tokenizer(examples["text"], padding="max_length", truncation=True)
 
-    data = data.map(tokenize_function, batched=True).with_format("torch")
+    data = data.map(tokenize_function, batched=True)
 
     if prepare:
-        data = data.map(prepare).with_format("torch")
+        data = data.map(prepare)
 
     return data.with_format("torch")

@@ -157,8 +160,8 @@ def transform(batch: SequenceClassificationBatch):
batch["input_ids"],
batch["label"],
)
input_ids = torch.tensor(input_ids, dtype=torch.int64)
label = torch.tensor(label, dtype=torch.int64)
input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
label = torch.as_tensor(label, dtype=torch.int64)
label = F.one_hot(label, num_classes=num_classes).to(torch.int32)
return SequenceClassificationBatch(
input_ids=input_ids,
@@ -192,21 +195,12 @@ def imdb_dataset(
     tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
 
     def prepare(sample):
-        input_ids, labels = (
-            sample["input_ids"],
-            sample["label"],
-        )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64)
-        labels = torch.tensor(labels, dtype=torch.int64)
-        labels = F.one_hot(labels, num_classes=2).to(torch.int32)
-
-        (
-            sample["input_ids"],
-            sample["label"],
-        ) = (
-            input_ids,
-            labels,
-        )
+        input_ids, labels = sample["input_ids"], sample["label"]
+        input_ids = np.array(input_ids, dtype=np.int64)
+        labels_tensor = torch.as_tensor(labels, dtype=torch.int64)
+        labels = F.one_hot(labels_tensor, num_classes=2).to(torch.int32).numpy()
+
+        sample["input_ids"], sample["label"] = input_ids, labels
         return sample
 
     train_data, test_data = (
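As a worked example of the new label path in `prepare`: a single binary label of 1 becomes a one-hot int32 row, and the closing `.numpy()` returns it to a plain array for the mapped sample (assuming label 1 means positive, as in the HF imdb dataset):

import torch
import torch.nn.functional as F

label = 1
one_hot = F.one_hot(torch.as_tensor(label, dtype=torch.int64), num_classes=2)
print(one_hot.to(torch.int32).numpy())  # [0 1]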
6 changes: 3 additions & 3 deletions src/scratch/datasets/token_classification_dataset.py
@@ -91,9 +91,9 @@ def transform(batch):
batch["attention_mask"],
batch["labels"],
)
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = torch.tensor(labels, dtype=torch.long)
input_ids = torch.as_tensor(input_ids, dtype=torch.long)
attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
labels = torch.as_tensor(labels, dtype=torch.long)
return TokenClassificationBatch(
input_ids=input_ids, attention_mask=attention_mask, labels=labels
)