PyTorch emits "UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead." when you use a Byte tensor (dtype=torch.uint8) for indexing or masking.
To fix this UserWarning, you should convert the Byte tensor to a Bool tensor (dtype=torch.bool) before indexing or masking.
Reproduce the warning
import torch
float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8) # Byte tensor (uint8)
# Indexing with a uint8 mask still works, but it emits a deprecation UserWarning
masked_tensor = float_tensor[mask]
Output
UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
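In a larger codebase the warning does not always make it obvious which line is responsible. A minimal sketch, assuming the warning reaches Python's standard warnings module as a UserWarning (which is how PyTorch surfaces its internal warnings), is to escalate only this message to an exception so the traceback points at the offending indexing call. The helper name raise_on_uint8_indexing is hypothetical, not part of PyTorch.

```python
import warnings
from contextlib import contextmanager

import torch


@contextmanager
def raise_on_uint8_indexing():
    """Hypothetical helper: turn the uint8-indexing deprecation warning into an error."""
    with warnings.catch_warnings():
        # Only messages starting with this text are escalated;
        # all other warnings keep their usual behavior.
        warnings.filterwarnings(
            "error",
            message="indexing with dtype torch.uint8",
            category=UserWarning,
        )
        yield


float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8)

with raise_on_uint8_indexing():
    try:
        float_tensor[mask]
    except UserWarning as exc:
        # The traceback of exc points at the line that indexed with the uint8 mask.
        print(f"Found uint8 indexing: {exc}")
```

Because the filter lives inside warnings.catch_warnings(), the previous warning configuration is restored when the with block exits.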
How to fix it
To fix the warning, convert the Byte tensor to a Bool tensor before using it as a mask.
import torch
float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.uint8) # Byte tensor (uint8)
mask = mask.to(torch.bool)  # convert the uint8 mask to a Bool mask
masked_tensor = float_tensor[mask]
print(masked_tensor)
Output
tensor([0.5000, 0.7000, 0.8000])
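Besides mask.to(torch.bool), the shorthand mask.bool() and the comparison mask != 0 also produce a torch.bool mask. The sketch below reuses the example tensors but sets one mask entry to 2 (an illustrative value) to check that all three spellings agree: every nonzero byte becomes True, so the same elements are selected as with the original uint8 mask.

```python
import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)
# Illustrative uint8 mask containing a value other than 1 to show nonzero -> True.
mask = torch.tensor([[1, 0, 2], [0, 1, 0]], dtype=torch.uint8)

# Three equivalent ways to obtain a torch.bool mask.
via_to = mask.to(torch.bool)
via_bool = mask.bool()
via_compare = mask != 0

# All three have dtype torch.bool and identical contents.
for m in (via_to, via_bool, via_compare):
    assert m.dtype == torch.bool
    assert torch.equal(m, via_to)

print(via_bool)
print(float_tensor[via_bool])
```

mask.bool() is the most compact; mask != 0 makes the "nonzero means selected" rule explicit in the code.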
By converting the Byte tensor to a Bool tensor, you can perform indexing or masking operations without encountering the warning.
Bool tensors have been the recommended mask dtype since PyTorch 1.2, which introduced torch.bool and deprecated masking with torch.uint8.
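Often the simplest way to avoid the warning is never to build a uint8 mask at all. Comparison operators such as > already return torch.bool tensors, and the same Bool mask works with torch.masked_select and Tensor.masked_fill. A short sketch using the example tensor from above; the 0.4 threshold is an arbitrary value chosen for illustration.

```python
import torch

float_tensor = torch.tensor([[0.5, 0.3, 0.7], [0.2, 0.8, 0.1]], dtype=torch.float)

# Comparisons produce a torch.bool tensor directly; no dtype conversion needed.
mask = float_tensor > 0.4  # 0.4 is an arbitrary illustrative threshold
print(mask.dtype)  # torch.bool

# Boolean indexing and masked_select both return the selected values as a 1-D tensor.
selected = float_tensor[mask]
selected_too = torch.masked_select(float_tensor, mask)

# masked_fill keeps the original shape; ~mask inverts a Bool mask.
zeroed = float_tensor.masked_fill(~mask, 0.0)

print(selected)
print(zeroed)
```

Use indexing or masked_select when you want only the selected values, and masked_fill when you need to keep the tensor's shape.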
That’s it.