Knowing the device (e.g. CPU, GPU) on which a PyTorch model is located is useful for several reasons, such as hardware resource management, performance optimization, compatibility and portability, and resource allocation and scaling. This tutorial explains how to get the device on which a PyTorch model is currently located.
Code
Most of the time, all the parameters of a model are located on the same device, so we can simply use the first parameter to determine the device of the entire model.
In the following code, we define a simple model that consists of a single linear layer. Using the parameters() method, we retrieve an iterator over all the parameters of the model. Next, we get the first parameter from the iterator with next(). Finally, we read the device attribute of that parameter, which is printed to the console.
import torch.nn as nn

# Define a simple model consisting of a single linear layer.
model = nn.Sequential(
    nn.Linear(1, 1)
)

# The device of the first parameter is the device of the model.
device = next(model.parameters()).device
print(device)
The above code will print either cpu or cuda:0, depending on the device on which the model is currently located.
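To see the second case, the same snippet can be run after moving the model to the GPU. Below is a minimal sketch that assumes a CUDA-capable GPU is available; if it is not, the model stays on the CPU and the printed device remains cpu.

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1)
)

# Move the model to the GPU if one is available (assumption: CUDA device present).
if torch.cuda.is_available():
    model = model.to("cuda")

# The same check now reports the new device.
device = next(model.parameters()).device
print(device)  # cuda:0 if a GPU was available, otherwise cpu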