Get Device on Which PyTorch Model is Currently Located

Knowing the device (e.g., CPU or GPU) on which a PyTorch model is located is useful for several reasons. For example, input tensors must be on the same device as the model's parameters before a forward pass. Device information also helps with hardware resource management, performance optimization, and writing code that runs on machines with or without a GPU. This tutorial explains how to get the device on which a PyTorch model is currently located.

Code

Most of the time, all the parameters of a model are located on the same device, so we can simply use the first parameter to determine the device of the entire model.
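If you are unsure whether that assumption holds for a given model, you can check it directly. The sketch below collects the distinct devices of all parameters and buffers. The helper name `parameter_devices` is hypothetical and not part of PyTorch. If the result contains more than one device, the first-parameter shortcut would be misleading for that model.

```python
from typing import Set

import torch
import torch.nn as nn


def parameter_devices(model: nn.Module) -> Set[torch.device]:
    """Hypothetical helper: return the set of devices used by the
    model's parameters and buffers."""
    devices = {p.device for p in model.parameters()}
    devices.update(b.device for b in model.buffers())
    return devices


model = nn.Sequential(
    nn.Linear(1, 1),
    nn.BatchNorm1d(1),  # also registers buffers (running statistics)
)

devices = parameter_devices(model)
print(devices)  # a single-element set for a model on one device

if len(devices) > 1:
    print("Model is split across devices:", devices)
```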

In the following code, we define a simple model that consists of a single linear layer. The parameters() method returns an iterator over all the parameters of the model, and next() retrieves the first one. Finally, we read the device attribute of that parameter and print it to the console.

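import torch.nn as nn

# A simple model with a single linear layer
model = nn.Sequential(
    nn.Linear(1, 1)
)

# Take the first parameter from the iterator and read its device
device = next(model.parameters()).device

print(device)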

Because PyTorch creates model parameters on the CPU by default, the above code prints cpu. If the model has been moved to a GPU first (for example, with model.to("cuda")), it prints cuda:0 instead.
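Note that next(model.parameters()) raises StopIteration if the model has no parameters, as is the case for nn.ReLU(). Below is a minimal sketch of a more defensive variant. It assumes a fallback order of first parameter, then first buffer, then a default device. The helper name `get_model_device` and its `default` argument are hypothetical and not PyTorch API. The sketch also shows a common use of the result: moving an input tensor to the model's device before the forward pass.

```python
import torch
import torch.nn as nn


def get_model_device(model: nn.Module, default: str = "cpu") -> torch.device:
    """Hypothetical helper: device of the first parameter, else of the
    first buffer, else `default` (assumed fallback for tensor-free models)."""
    param = next(model.parameters(), None)
    if param is not None:
        return param.device
    buffer = next(model.buffers(), None)
    if buffer is not None:
        return buffer.device
    return torch.device(default)


model = nn.Sequential(
    nn.Linear(1, 1)
)

# Move the model to a GPU only if one is available.
if torch.cuda.is_available():
    model = model.to("cuda")

device = get_model_device(model)
print(device)  # cpu, or cuda:0 if a GPU was available

# Inputs must live on the same device as the model's parameters.
x = torch.randn(4, 1).to(device)
y = model(x)
print(y.device)

# ReLU has no parameters or buffers, so the default device is returned.
print(get_model_device(nn.ReLU()))
```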
