Managing Your Data in PyTorch
Hey there! In this blog post, I will share some essential notes on implementing deep learning models, which you can use in your own projects, particularly those built with PyTorch. If you have ever implemented contrastive learning models, you have probably run into out-of-memory errors because of the large batch sizes these methods require. If you have limited resources, or the efficiency of your code matters to you, these notes can be very useful.
Workers
Workers play a central role in training neural networks: they load data from disk into memory and feed it to the model. Workers run in parallel. Each one loads a batch and places it in a queue, and the model trains on whichever batch becomes available in the queue. The number of workers matters and can strongly affect your training procedure. In the rest of this post, I will give some notes on how to choose the number of workers.
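To make the queue idea concrete, here is a toy sketch in plain Python. It is only a conceptual model: PyTorch's DataLoader uses worker processes rather than threads, and names such as `load_batch` and `run_pipeline` are hypothetical helpers I made up for this illustration.

```python
import queue
import threading

def load_batch(batch_idx, batch_size):
    """Stand-in for reading and augmenting one batch from disk."""
    start = batch_idx * batch_size
    return list(range(start, start + batch_size))

def worker(worker_id, num_workers, num_batches, batch_size, out_queue):
    # Each worker handles every num_workers-th batch, then signals it is done.
    for batch_idx in range(worker_id, num_batches, num_workers):
        out_queue.put((batch_idx, load_batch(batch_idx, batch_size)))
    out_queue.put(None)

def run_pipeline(num_workers=2, num_batches=6, batch_size=4, max_prefetch=4):
    # A bounded queue caps how many finished batches wait in memory.
    out_queue = queue.Queue(maxsize=max_prefetch)
    threads = [
        threading.Thread(
            target=worker,
            args=(w, num_workers, num_batches, batch_size, out_queue),
        )
        for w in range(num_workers)
    ]
    for t in threads:
        t.start()

    consumed, finished = [], 0
    while finished < num_workers:
        item = out_queue.get()
        if item is None:          # one worker has no more batches
            finished += 1
            continue
        consumed.append(item)     # the training step would run here
    for t in threads:
        t.join()
    return consumed

batches = run_pipeline()
print("batches consumed:", len(batches))
```

The consumer never waits for a specific worker. It simply takes the next batch that is ready, which is why slow loading only hurts when the queue runs empty.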
CPU
You may have used transforms in PyTorch to augment your data. In essence, workers apply these augmentations after loading the data and just before feeding it to the model, so workers are CPU-bound and limited by the speed of the CPU. If you use a large number of workers, your dataset is large and complex, or you have a large batch size, make sure you have a powerful CPU. Otherwise, the CPU will bottleneck your training process.
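One rough way to reason about the CPU bottleneck is to compare how long one worker needs to prepare a batch with how long the model needs to consume it. The sketch below is back-of-the-envelope arithmetic, not a PyTorch API. `min_workers_needed` is a hypothetical helper, and the timings are made-up numbers you would replace with your own measurements.

```python
import math

def min_workers_needed(batch_size, sec_per_sample, sec_per_step):
    """Smallest number of workers that keeps data loading ahead of the model.

    sec_per_sample: time one worker spends loading and augmenting one sample.
    sec_per_step:   time the model spends on one training step (one batch).
    """
    sec_per_batch_one_worker = batch_size * sec_per_sample
    return max(1, math.ceil(sec_per_batch_one_worker / sec_per_step))

# Assumed timings: 4 ms per sample on the CPU, 100 ms per training step.
needed = min_workers_needed(batch_size=128, sec_per_sample=0.004, sec_per_step=0.1)
print("workers needed:", needed)
```

If the result is larger than the number of CPU cores you have, adding workers will not help. The augmentations themselves are too expensive for the CPU.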
RAM
Moreover, using a large number of workers can significantly increase RAM usage, particularly in tasks like contrastive learning where the batch size is large. This is because each worker process loads a full batch of data into memory before sending it to the main process. For example, if you are using 4 workers and a batch size of 128, each worker process loads 128 samples into memory before sending them to the main process. This means your machine needs enough memory to hold at least 4 batches of data at once.
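The example above can be turned into a quick estimate. This is a minimal sketch under a few assumptions: samples are float32 tensors of shape 3×96×96 (the STL10 image size that appears in the results later in this post), and each worker keeps up to `prefetch_factor` batches ready, which defaults to 2 in DataLoader when workers are used. `batch_mib` and `in_flight_mib` are hypothetical helper names, and the numbers ignore Python overhead and temporary copies.

```python
import math

def batch_mib(batch_size, sample_shape, bytes_per_value=4):
    """Size of one batch in MiB (float32 uses 4 bytes per value)."""
    return batch_size * math.prod(sample_shape) * bytes_per_value / 2**20

def in_flight_mib(num_workers, batch_size, sample_shape, prefetch_factor=2):
    """Rough upper bound on memory held by batches waiting for the model."""
    waiting_batches = num_workers * prefetch_factor
    return waiting_batches * batch_mib(batch_size, sample_shape)

one_batch = batch_mib(128, (3, 96, 96))
buffered = in_flight_mib(num_workers=4, batch_size=128, sample_shape=(3, 96, 96))
print(f"one batch: {one_batch:.1f} MiB, waiting batches: {buffered:.1f} MiB")
```

The same arithmetic with a contrastive-learning batch size of 1024 gives eight times these numbers, which is where the RAM pressure comes from.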
Conclusion
You might expect that a large number of workers always improves the performance of your training loop. However, with more workers the main process has to spend more time communicating with the worker processes, and you also need a powerful CPU and plenty of RAM. A good rule of thumb is to start with a number of workers equal to the number of CPU cores on your machine. This lets the worker processes load data from disk as quickly as possible without overloading the main process, but you should still keep the notes above in mind.
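The rule of thumb can be sketched in a few lines. `suggest_num_workers` is a hypothetical helper, and the `reserved_cores` and `max_workers` knobs are my own additions so you can back off when RAM or CPU becomes the limit.

```python
import os

def suggest_num_workers(reserved_cores=0, max_workers=None):
    """Start from the number of CPU cores, optionally leaving some free."""
    cores = os.cpu_count() or 1        # os.cpu_count() may return None
    workers = max(cores - reserved_cores, 0)
    if max_workers is not None:
        workers = min(workers, max_workers)
    return workers

num_workers = suggest_num_workers()
print("suggested num_workers:", num_workers)
```

The result can be passed as the `num_workers` argument of DataLoader. Treat it as a starting point and measure, rather than as the final answer.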
Data Augmentation
As I mentioned before, workers apply augmentations to the data after loading it, so these augmentations consume RAM and CPU. If you are using a large number of augmentations, or if your dataset is large or complex, you may encounter high RAM usage. One feasible solution is to apply the augmentations before the training process: save the augmented data to disk and load it during training, which can significantly reduce RAM usage. I prepared a piece of code that you can use in your implementations:
Regular data loading
# Download and load STL10 dataset with on-the-fly augmentation
# (root_path, transform and BATCH_SIZE are yours to define; get_ram_usage() is defined in step 3 below)
from torch.utils.data import DataLoader
from torchvision.datasets import STL10
train_dataset = STL10(root=root_path, split='unlabeled', download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
get_ram_usage()
# Example training loop
for epoch in range(1):
    for data, labels in train_loader:
        get_ram_usage()
        print(data.size())
        # Your training loop goes here
        pass
Results
RAM usage: 3046.1875 MB
RAM usage: 3333.125 MB
torch.Size([1024, 3, 96, 96])
RAM usage: 3441.2578125 MB
torch.Size([1024, 3, 96, 96])
RAM usage: 3441.30078125 MB
torch.Size([1024, 3, 96, 96])
RAM usage: 3441.3203125 MB
torch.Size([1024, 3, 96, 96])
RAM usage: 3441.33984375 MB
torch.Size([1024, 3, 96, 96])
RAM usage: 3441.359375 MB
torch.Size([1024, 3, 96, 96])
Suggested data loading
1- Save Augmented Data
import os
import torch
# Create a directory to save augmented batches
save_dir = 'augmented_batches'
os.makedirs(save_dir, exist_ok=True)
# Generate and save augmented batches
num_augmentations = 1
total_batches = len(train_loader)
for aug_idx in range(num_augmentations):
    for batch_idx, augmented_batch in enumerate(train_loader):
        # Save the augmented batch
        save_path = os.path.join(save_dir, f'batch_{batch_idx}_aug_{aug_idx}.pt')
        torch.save(augmented_batch, save_path)
        print(f'Saved batch {batch_idx + 1}/{total_batches}')
print("Augmented batches saved successfully.")
2- Restart Environment to clear out the RAM
3- Load Saved Augmented data
import os
import torch
from torch.utils.data import DataLoader, Dataset
import psutil
def get_ram_usage():
    process = psutil.Process()  # handle to the current Python process
    ram_usage = process.memory_info().rss / float(2 ** 20)  # RAM usage in MB
    print("RAM usage:", ram_usage, "MB")
    return ram_usage
class CustomDataset(Dataset):
    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        self.transform = transform
        self.file_list = os.listdir(data_folder)
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, idx):
        # Each file holds one pre-augmented batch saved in step 1
        file_path = os.path.join(self.data_folder, self.file_list[idx])
        print(file_path)
        data = torch.load(file_path)
        return data
# Create a custom dataset using your saved data
custom_dataset = CustomDataset("./augmented_batches")
# batch_size=1 because every file already contains a full batch
train_loader = DataLoader(custom_dataset, batch_size=1, shuffle=True)
get_ram_usage()
# Example training loop
for epoch in range(1):
    for data, labels in train_loader:
        get_ram_usage()
        print(data.size())
        pass
Results
RAM usage: 408.6875 MB
./augmented_batches/batch_83_aug_0.pt
RAM usage: 520.82421875 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_97_aug_0.pt
RAM usage: 483.703125 MB
torch.Size([1, 672, 3, 96, 96])
./augmented_batches/batch_82_aug_0.pt
RAM usage: 591.7109375 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_65_aug_0.pt
RAM usage: 628.8671875 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_74_aug_0.pt
RAM usage: 628.87109375 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_96_aug_0.pt
RAM usage: 617.859375 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_35_aug_0.pt
RAM usage: 628.87109375 MB
torch.Size([1, 1024, 3, 96, 96])
./augmented_batches/batch_12_aug_0.pt
RAM usage: 628.87109375 MB
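The price of this approach is disk space, since every augmented batch is stored as a full tensor. Here is a rough estimate, assuming the 100,000 unlabeled STL10 images are saved as float32 tensors of shape 3×96×96 (this depends on the transform you use, and it ignores the labels and the small overhead of the file format). `saved_dataset_gib` and `num_batch_files` are hypothetical helper names.

```python
import math

def saved_dataset_gib(num_samples, sample_shape, bytes_per_value=4, num_augmentations=1):
    """Approximate disk space in GiB for the saved augmented tensors."""
    values_per_sample = math.prod(sample_shape)
    total_bytes = num_samples * values_per_sample * bytes_per_value * num_augmentations
    return total_bytes / 2**30

def num_batch_files(num_samples, batch_size):
    """Number of .pt files written per augmentation pass."""
    return math.ceil(num_samples / batch_size)

size_gib = saved_dataset_gib(100_000, (3, 96, 96))
files = num_batch_files(100_000, 1024)
print(f"{files} batch files, about {size_gib:.1f} GiB on disk")
```

The file count agrees with the output above, where the last batch (`batch_97`) holds the remaining 672 samples. Keep in mind that the size grows linearly with `num_augmentations`.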
Mixed Precision
Most of the time, models use 32-bit floating-point numbers for parameters and activations. These provide high precision but require more memory and computational resources. Fortunately, there is a technique for reducing this cost, called mixed precision. It can greatly improve the efficiency of your code by combining 32-bit and 16-bit floating-point numbers, because some parts of training, such as much of the forward and backward passes, do not need full float32 precision. Not only does it reduce memory usage, it can also speed up computation. However, you should pay close attention to precision loss, especially to numerical stability in certain operations.
Certain operations, such as linear layers and convolutions, run much faster in float16 or bfloat16. Reductions, on the other hand, often need the dynamic range of float32. In PyTorch, torch.cuda.amp offers tools for mixed precision whose goal is to match each operation with its appropriate data type, thereby reducing both the runtime and the memory usage of the network.
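To see why reductions are sensitive, here is a small emulation in plain Python. It does not use PyTorch at all. It rounds numbers to half and single precision with the standard `struct` module (format codes `"e"` and `"f"`), and `running_sum` is a hypothetical helper that rounds the accumulator after every addition.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest float16 (IEEE 754 half precision)."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x):
    """Round a Python float to the nearest float32."""
    return struct.unpack("f", struct.pack("f", x))[0]

def running_sum(value, count, round_fn):
    """Add `value` `count` times, rounding the accumulator after each step."""
    total = 0.0
    step = round_fn(value)
    for _ in range(count):
        total = round_fn(total + step)
    return total

sum_fp16 = running_sum(0.001, 10_000, to_fp16)
sum_fp32 = running_sum(0.001, 10_000, to_fp32)

try:
    to_fp16(70_000.0)   # above the largest finite float16 value (65504)
    fp16_overflows = False
except (OverflowError, struct.error):
    fp16_overflows = True

print("float16 sum:", sum_fp16)
print("float32 sum:", sum_fp32)
print("70000 overflows float16:", fp16_overflows)
```

The exact sum is 10, and the float32 accumulator lands very close to it. The float16 accumulator stalls far below that, because once the running total is large enough, adding 0.001 no longer changes the nearest representable value. This is the kind of operation autocast keeps in float32.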
Basic Components of AMP:
torch.autocast: The autocast context manager marks the region of code where lower-precision operations should be performed. Inside this region, CUDA ops run in a data type chosen by autocast to improve efficiency while preserving accuracy.
for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            assert output.dtype is torch.float16
            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.
            assert loss.dtype is torch.float32
        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
torch.cuda.amp.GradScaler: To counter precision loss during the backward pass, AMP uses gradient scaling. The loss is multiplied by a scaling factor before the backward pass, so that small gradient values do not underflow to zero in float16. The gradients are then divided by the same factor (unscaled) before the optimizer updates the parameters.
# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.cuda.amp.GradScaler()
for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()
        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)
        # Updates the scale for next iteration.
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
Note: these code examples are from the PyTorch documentation. You can check it for more details:
Automatic Mixed Precision — PyTorch Tutorials 2.1.0+cu121 documentation
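To build some intuition for why gradient scaling helps, here is a tiny emulation of my own in plain Python (it is not taken from the PyTorch documentation). It rounds values to float16 with the standard `struct` module. The gradient value and the scale factor of 2**16 are assumptions chosen for illustration, and in real AMP the loss is what gets scaled, with the factor then carrying through to every gradient via the chain rule.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest float16 (IEEE 754 half precision)."""
    return struct.unpack("e", struct.pack("e", x))[0]

true_grad = 1e-8        # assumed gradient value, too small for float16
scale = 2.0 ** 16       # assumed scale factor for this illustration

unscaled = to_fp16(true_grad)          # what float16 stores without scaling
scaled = to_fp16(true_grad * scale)    # the scaled value fits in float16
recovered = scaled / scale             # unscale in higher precision

print("stored without scaling:", unscaled)
print("recovered after scaling:", recovered)
```

Without scaling, the gradient is flushed to zero and the corresponding weight would never be updated. With scaling, the value survives the float16 round trip with only a small relative error.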
Last words
I hope you found these notes helpful. Remember, refining your code for efficiency is a continuous process. Feel free to share your thoughts, comments, or any corrections. Your feedback is invaluable :) Happy coding!