How to Fix RuntimeError: expected scalar type double but found float

To fix the RuntimeError: expected scalar type double but found float error, make the data types on both sides match. Either convert the input tensor to double with the .double() or .to(torch.float64) method, or convert the model to float with the .float() or .to(torch.float) method.

The RuntimeError: expected scalar type double but found float error occurs when the data type of the input tensor does not match the data type of the model's weights. Here, the model's weights are double-precision floating point (torch.float64, also called torch.double, the dtype of a torch.DoubleTensor), but the input tensor is single-precision floating point (torch.float32, also called torch.float, the dtype of a torch.FloatTensor).

Here’s a code example that will reproduce the error.

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)

# Create the model and convert it to double
model = SimpleModel().double()

# Create an input tensor with the data type torch.float (single-precision)
input_tensor = torch.randn(5, 10, dtype=torch.float)

# Perform a forward pass, which will result in an error due to data type mismatch
output = model(input_tensor)

Output

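RuntimeError: expected scalar type Double but found Float

Depending on your PyTorch version, the same mismatch may be reported with a different message, such as RuntimeError: mat1 and mat2 must have the same dtype.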
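Before picking a fix, it can help to confirm which side is double and which is float. The sketch below is a minimal diagnostic, assuming a bare nn.Linear(10, 1) as a stand-in for SimpleModel and treating the dtype of the first parameter as the model's dtype (this holds when the whole model was converted with .double()).

```python
import torch
import torch.nn as nn

# Stand-in for SimpleModel: a single nn.Linear(10, 1) converted to double
model = nn.Linear(10, 1).double()
input_tensor = torch.randn(5, 10, dtype=torch.float)

# Assumption: all parameters share one dtype, so the first one represents the model
param_dtype = next(model.parameters()).dtype

print("model parameters:", param_dtype)         # torch.float64
print("input tensor:", input_tensor.dtype)      # torch.float32
print("match:", param_dtype == input_tensor.dtype)  # False
```

If the two dtypes differ, either of the solutions below brings them back in line.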

How to fix it?

Solution 1: Convert the input tensor to double using the double() method

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)

model = SimpleModel().double()
input_tensor = torch.randn(5, 10, dtype=torch.float)

# Convert input tensor to double
input_tensor = input_tensor.double()

output = model(input_tensor)
print(output)

Output

tensor([[ 0.7394],
        [-0.0233],
        [-0.6937],
        [ 0.8246],
        [ 0.3691]], dtype=torch.float64, grad_fn=<AddmmBackward0>)
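The .to() method mentioned at the start works the same way as .double(), and if you control where the input is created, you can also create it as float64 from the start. A minimal sketch, again using a bare nn.Linear(10, 1) as an assumed stand-in for SimpleModel:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1).double()

# Option A: convert an existing float32 tensor with .to()
input_tensor = torch.randn(5, 10, dtype=torch.float)
input_a = input_tensor.to(torch.float64)  # same result as input_tensor.double()

# Option B: create the tensor as float64 in the first place
input_b = torch.randn(5, 10, dtype=torch.float64)

for x in (input_a, input_b):
    output = model(x)
    print(output.dtype)  # torch.float64
```

Option B avoids an extra conversion step, which matters mostly when the tensors are large.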

Solution 2: Convert the model to float

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)


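# Start from a double model (as in the failing example), then convert it back to float
model = SimpleModel().double()
model = model.to(torch.float)  # equivalent to model.float()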
input_tensor = torch.randn(5, 10, dtype=torch.float)

output = model(input_tensor)
print(output)

Output

tensor([[ 0.0163],
        [-0.2952],
        [-0.0456],
        [-0.0267],
        [ 0.9060]], grad_fn=<AddmmBackward0>)
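To check that the conversion reached every weight, you can collect the dtypes of all parameters. This sketch assumes a bare nn.Linear(10, 1) as a stand-in for SimpleModel and relies on torch.randn producing float32 by default (true unless the global default dtype has been changed).

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1).double()

# .float() is shorthand for .to(torch.float32); it converts the module's
# floating-point parameters and buffers in place and returns the module
model = model.float()

# Every parameter should now be float32
dtypes = {p.dtype for p in model.parameters()}
print(dtypes)  # {torch.float32}

output = model(torch.randn(5, 10))  # float32 input by default
print(output.dtype)  # torch.float32
```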

After converting the model's weights to float32, the forward pass runs without the error.

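Keep in mind that single precision (float32) is less precise than double precision (float64): roughly 7 significant decimal digits instead of about 15 to 16. In exchange, each value takes 4 bytes instead of 8, so a float32 model needs half the memory for its weights, and computation is usually faster, especially on GPUs.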
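The precision and memory trade-off can be measured directly. The sketch below compares machine epsilon and bytes per element for the two dtypes, and uses a hypothetical helper, param_bytes, that counts only parameter memory (not buffers, activations, or optimizer state). The nn.Linear(1000, 1000) layer is an arbitrary example with 1,000,000 weights plus 1,000 biases.

```python
import torch
import torch.nn as nn

# Machine epsilon: the gap between 1.0 and the next representable value
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07
print(torch.finfo(torch.float64).eps)  # 2.220446049250313e-16

# Bytes per element for each dtype
print(torch.empty(0, dtype=torch.float32).element_size())  # 4
print(torch.empty(0, dtype=torch.float64).element_size())  # 8

def param_bytes(model: nn.Module) -> int:
    """Hypothetical helper: total bytes used by a model's parameters only."""
    return sum(p.numel() * p.element_size() for p in model.parameters())

model = nn.Linear(1000, 1000)        # 1,001,000 parameters, float32 by default
print(param_bytes(model))            # 4004000
print(param_bytes(model.double()))   # 8008000 (.double() converts the model in place)
```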

Both solutions fix the error by making the data types of the input tensor and the model's weights match. Choose one based on your precision and memory requirements.
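If you have already settled on the model's precision, you can let the model decide and cast each input to match. The helper below, align_input_dtype, is a hypothetical name for illustration, not a PyTorch API. It assumes the model has at least one parameter and that all parameters share a single floating-point dtype.

```python
import torch
import torch.nn as nn

def align_input_dtype(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast x to the dtype of the model's first parameter.

    Tensor.to() returns x itself when the dtype already matches, so no copy
    is made in that case.
    """
    return x.to(next(model.parameters()).dtype)

# Works in both directions: float input to a double model, double input to a float model
double_model = nn.Linear(10, 1).double()
float_model = nn.Linear(10, 1)

out1 = double_model(align_input_dtype(double_model, torch.randn(5, 10)))
out2 = float_model(align_input_dtype(float_model, torch.randn(5, 10, dtype=torch.float64)))
print(out1.dtype, out2.dtype)  # torch.float64 torch.float32
```

This keeps the model's precision fixed and adapts the data, which is Solution 1 applied generically.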

That’s it.
