How to Fix UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead

PyTorch raises UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead when you use a Byte tensor (dtype=torch.uint8) to index or mask another tensor.

To fix this UserWarning, convert the Byte tensor to a Bool tensor (dtype=torch.bool) before using it for indexing or masking.
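PyTorch offers a few equivalent ways to perform that conversion. The sketch below uses a hypothetical mask_u8 holding the same values as the mask used in this post. All three variants map every nonzero entry to True and produce identical torch.bool tensors.

```python
import torch

# Hypothetical uint8 mask with the same values as the example in this post
mask_u8 = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8)

# Three equivalent ways to get a torch.bool mask
mask_a = mask_u8.bool()          # shorthand for .to(torch.bool)
mask_b = mask_u8.to(torch.bool)  # explicit dtype conversion
mask_c = mask_u8 != 0            # comparison ops already return torch.bool

for m in (mask_a, mask_b, mask_c):
    print(m.dtype, m.tolist())
```

Which one you pick is mostly a matter of taste; .bool() is the shortest, while mask_u8 != 0 makes the "nonzero means selected" rule explicit.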

Reproduce the error

import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8) # Byte tensor (uint8)

# Indexing with a uint8 mask triggers the deprecation UserWarning
masked_tensor = float_tensor[mask]

Output

UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

How to fix it

To fix the warning, convert the Byte tensor to a Bool tensor before indexing.

import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8) # Byte tensor (uint8)

mask = mask.to(torch.bool)  # convert the uint8 mask to a bool mask
masked_tensor = float_tensor[mask]
print(masked_tensor)

Output

tensor([0.5000, 0.7000, 0.8000])

By converting the Byte tensor to a Bool tensor, you can perform indexing or masking operations without encountering the warning.
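If you want to confirm that the bool mask no longer triggers the deprecation message, you can record warnings with Python's standard warnings module. This is a minimal sketch; the uint8_warnings name is just an illustrative variable, and the filter simply looks for "torch.uint8" in any recorded warning message.

```python
import warnings

import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8).to(torch.bool)

# Record every warning raised inside the block instead of printing it
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    masked_tensor = float_tensor[mask]

# Keep only warnings that mention the uint8 deprecation
uint8_warnings = [w for w in caught if "torch.uint8" in str(w.message)]
print(len(uint8_warnings))  # 0 with a bool mask
print(masked_tensor)
```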

Using Bool tensors for indexing or masking is the recommended approach in current versions of PyTorch; uint8 masks still work, but they are deprecated and emit this warning.
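Often the easiest way to avoid the warning is to never create a uint8 mask at all: comparison operators such as > already return torch.bool. In the sketch below, the threshold 0.4 is an arbitrary value chosen so that the resulting mask matches the one in this post. The same bool mask is then reused for plain indexing and for masked_fill.

```python
import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)

# A comparison returns a torch.bool tensor, so no conversion step is needed
mask = float_tensor > 0.4  # arbitrary threshold for illustration
print(mask.dtype)  # torch.bool

# The same bool mask works for indexing and for other masking ops
selected = float_tensor[mask]                 # 1-D tensor of the selected values
zeroed = float_tensor.masked_fill(mask, 0.0)  # replace masked entries with 0
print(selected)
print(zeroed)
```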

That’s it.
