How to run PyTorch on the M1 Mac GPU


As with TensorFlow, it takes only a few steps to enable a Mac with an M1 chip (aka Apple silicon) to run machine learning tasks in Python with the PyTorch framework. The steps shown in this post summarize this blog post by Nikos Kafritsas and this blog post by Sudhanva Narayana (GitHub).


Update: PyTorch is now fully available for Apple Silicon. Thus, it’s no longer necessary to follow the instructions below. Just follow the installation instructions in this update post: PyTorch on Apple Silicon.

Pre-check

Again, before we begin, please make sure that you have installed the ARM version of Miniconda for macOS. To check this, activate any existing conda virtual environment, start a Python session and execute:

import platform
platform.platform()

You should receive something like:

  'macOS-12.3-arm64-arm-64bit'

If the returned string contains arm64, as above, jump to the next section. Otherwise, you need to

  1. uninstall your existing conda installation, and
  2. install the miniconda macOS ARM version, e.g. Miniconda3 macOS Apple M1 64-bit pkg.
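If you prefer a check that does not rely on reading the platform string by eye, the minimal sketch below uses platform.machine(), which reports the architecture the running interpreter was built for. It assumes the usual behavior that an Intel (x86_64) Python running under Rosetta 2 reports x86_64 here; the helper name is_native_arm64_python is hypothetical.

```python
import platform


def is_native_arm64_python() -> bool:
    # Hypothetical helper: True only for a macOS interpreter built for arm64.
    # An x86_64 (Intel) Python running under Rosetta 2 reports "x86_64" instead.
    return platform.system() == "Darwin" and platform.machine() == "arm64"


print(platform.platform())
print("native arm64 Python:", is_native_arm64_python())
```

If this prints False on an M1 Mac, the active conda installation is most likely the Intel build.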

Install PyTorch

  1. Create and activate a virtual conda environment:

    conda create --name conda_pytorch python=3.9
    conda activate conda_pytorch
    
  2. Install pip and some other packages, which we will need for the evaluation later:

    conda install pip ipykernel jupyter notebook matplotlib -y
    
  3. Install the PyTorch nightly build (torch, torchvision and torchaudio):

    pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
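Since the pip install in step 3 goes into whichever environment is currently active, a quick sanity check that you are really inside conda_pytorch with Python 3.9 can save some confusion. This is a minimal sketch assuming that conda activate exports the environment name in the CONDA_DEFAULT_ENV variable (as current conda releases do); the function name environment_problems is hypothetical.

```python
import os
import sys


def environment_problems(expected_env="conda_pytorch", expected_python=(3, 9)):
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    # `conda activate <name>` exports the environment name in CONDA_DEFAULT_ENV.
    active_env = os.environ.get("CONDA_DEFAULT_ENV")
    if active_env != expected_env:
        problems.append(f"active conda env is {active_env!r}, expected {expected_env!r}")
    # Compare only (major, minor) so patch releases such as 3.9.13 still pass.
    if tuple(sys.version_info[:2]) != tuple(expected_python):
        problems.append(
            f"Python is {sys.version_info.major}.{sys.version_info.minor}, "
            f"expected {expected_python[0]}.{expected_python[1]}"
        )
    return problems


for problem in environment_problems():
    print("warning:", problem)
```

No output means the environment looks right for the pip command.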
    

And that’s it!

To verify that everything is installed correctly, open a Python session and run:

import torch
print(f"torch backend MPS is available? {torch.backends.mps.is_available()}")
print(f"current PyTorch installation built with MPS activated? {torch.backends.mps.is_built()}")
print(f"check the torch MPS backend: {torch.device('mps')}")
print(f"test torch tensor on MPS: {torch.tensor([1,2,3], device='mps')}")

If you get the following responses,

  torch backend MPS is available? True
  current PyTorch installation built with MPS activated? True
  check the torch MPS backend: mps
  test torch tensor on MPS: tensor([1, 2, 3], device='mps:0')

everything is set up correctly. mps stands for Metal Performance Shaders, Apple’s framework of GPU-accelerated compute kernels built on its Metal API, which PyTorch uses as its backend for the M1 GPU.

Note: GPU support on the M1 requires macOS 12.3 or later.
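If one of the checks above prints False, the two flags tell you different things: is_built() reports whether your PyTorch binary was compiled with MPS support, while is_available() additionally requires a supported macOS version and an MPS-capable device. The sketch below turns the two flags into a short diagnosis; the function name diagnose_mps is hypothetical, and its messages paraphrase the check shown in the PyTorch MPS backend documentation.

```python
def diagnose_mps(built: bool, available: bool) -> str:
    """Map torch.backends.mps.is_built()/is_available() to a short diagnosis."""
    if available:
        return "MPS is available; use torch.device('mps')"
    if not built:
        return "this PyTorch binary was built without MPS support; reinstall PyTorch"
    # Built with MPS, but the runtime check failed.
    return "PyTorch supports MPS, but macOS is older than 12.3 or no MPS-capable GPU was found"


try:
    import torch
except ImportError:  # PyTorch not installed yet
    torch = None

if torch is not None:
    print(diagnose_mps(torch.backends.mps.is_built(), torch.backends.mps.is_available()))
```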

Benchmark test

We can benchmark PyTorch using the following code snippet from Nikos’ blog post, run as a Jupyter notebook cell (the %%time cell magic reports the cell’s runtime). To run PyTorch on the M1 GPU, we have to set the device to mps with torch.device("mps"); for an Nvidia GPU we would use torch.device("cuda"), and for the CPU torch.device("cpu"):

%%time
import math
import torch

dtype = torch.float
device = torch.device("mps")
# alternative that falls back to the CPU when MPS is unavailable:
# device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# create evenly spaced inputs on [-pi, pi] and the target outputs y = sin(x):
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# randomly initialize weights:
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
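The backprop lines in the loop use gradients derived by hand rather than PyTorch’s autograd. Writing ŷ_i for y_pred at the i-th of the N = 2000 sample points and η for learning_rate, the sum-of-squares loss and its gradients are:

```latex
L = \sum_{i=1}^{N} \left(\hat{y}_i - y_i\right)^2,
\qquad \hat{y}_i = a + b\,x_i + c\,x_i^2 + d\,x_i^3

\frac{\partial L}{\partial \hat{y}_i} = 2\left(\hat{y}_i - y_i\right)

\frac{\partial L}{\partial a} = \sum_{i=1}^{N} 2\left(\hat{y}_i - y_i\right), \qquad
\frac{\partial L}{\partial b} = \sum_{i=1}^{N} 2\left(\hat{y}_i - y_i\right) x_i

\frac{\partial L}{\partial c} = \sum_{i=1}^{N} 2\left(\hat{y}_i - y_i\right) x_i^2, \qquad
\frac{\partial L}{\partial d} = \sum_{i=1}^{N} 2\left(\hat{y}_i - y_i\right) x_i^3

\theta \leftarrow \theta - \eta \, \frac{\partial L}{\partial \theta},
\qquad \theta \in \{a, b, c, d\}, \quad \eta = 10^{-6}
```

In the code, grad_y_pred holds the per-point derivative 2(ŷ_i − y_i), grad_a to grad_d are the four sums, and the last four lines of the loop apply the update rule.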

The total runtime on my Mac was 2.49s.
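To compare the CPU and the M1 GPU under the same conditions, a small timing harness can help. This is a minimal sketch, not Nikos’ benchmark: fit_cubic is a stripped-down copy of the loop above (without the loss printing), and the helper names preferred_device_name, time_call and fit_cubic are hypothetical. It assumes PyTorch 2.0 or later, where torch.mps.synchronize() exists; on older nightlies, drop that call. Because GPU kernels are queued asynchronously, the harness waits for them before stopping the clock.

```python
import math
import time

try:
    import torch
except ImportError:  # PyTorch not installed: only the pure helpers below are usable
    torch = None


def preferred_device_name(mps_available: bool, cuda_available: bool) -> str:
    # Prefer the Apple GPU, then an Nvidia GPU, and fall back to the CPU.
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


def time_call(fn, synchronize=None):
    """Run fn() once and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn()
    if synchronize is not None:
        synchronize()  # GPU kernels run asynchronously; wait for them to finish
    return result, time.perf_counter() - start


def fit_cubic(device, steps=2000, learning_rate=1e-6):
    # Stripped-down copy of the benchmark loop (no loss printing), run on `device`.
    x = torch.linspace(-math.pi, math.pi, 2000, device=device)
    y = torch.sin(x)
    a, b, c, d = (torch.randn((), device=device) for _ in range(4))
    for _ in range(steps):
        y_pred = a + b * x + c * x ** 2 + d * x ** 3
        grad_y_pred = 2.0 * (y_pred - y)
        a -= learning_rate * grad_y_pred.sum()
        b -= learning_rate * (grad_y_pred * x).sum()
        c -= learning_rate * (grad_y_pred * x ** 2).sum()
        d -= learning_rate * (grad_y_pred * x ** 3).sum()
    return a, b, c, d


if torch is not None:
    best = preferred_device_name(torch.backends.mps.is_available(), torch.cuda.is_available())
    for name in sorted({"cpu", best}):
        device = torch.device(name)
        if name == "mps":
            sync = torch.mps.synchronize  # assumes PyTorch >= 2.0
        elif name == "cuda":
            sync = torch.cuda.synchronize
        else:
            sync = None
        coeffs, seconds = time_call(lambda: fit_cubic(device), sync)
        print(f"{name}: {seconds:.2f} s, a, b, c, d = {[round(t.item(), 3) for t in coeffs]}")
```

The first MPS run may include one-off setup cost, so timing a second call can give a fairer picture of the steady-state speed.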
