pytoch M2芯片测试

news/2024/7/9 8:53:51 标签: sklearn

今天才发现我的新片是M2芯片,而不是M1芯片,有点尴尬
在这里插入图片描述
参考网址
https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/

测试结果如下

M2_cpu.py

# https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device
device = "cpu"
device = torch.device(device)
print(f"Using device: {device}")


# Define the CNN model
class HandwritingRecognitionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Define the convolutional layers
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Define the pooling and dropout layers
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Define the fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass the input through the convolutional layers
        x = self.conv1(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout2(x)

        # Reshape the output for the fully connected layers
        x = x.view(-1, 32 * 7 * 7)

        # Pass the output through the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)

        # Return the final output
        return x


# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())

# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
model = HandwritingRecognitionModel().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):
    # Set the model to training mode
    model.train()

    # Iterate over the training data
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Pass the input through the model
        outputs = model(images)

        # Compute the loss
        loss = loss_fn(outputs, labels)

        # Backpropagate the error
        loss.backward()

        # Update the model parameters
        optimizer.step()

    # Set the model to evaluation mode
    model.eval()

    # Evaluate the model on the validation set
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Pass the input through the model
            outputs = model(images)

            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述

M2_MPS.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = torch.device(device)
print(f"Using device: {device}")


# Define the CNN model
class HandwritingRecognitionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Define the convolutional layers
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Define the pooling and dropout layers
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Define the fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass the input through the convolutional layers
        x = self.conv1(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout2(x)

        # Reshape the output for the fully connected layers
        x = x.view(-1, 32 * 7 * 7)

        # Pass the output through the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)

        # Return the final output
        return x


# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())

# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
model = HandwritingRecognitionModel().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):
    # Set the model to training mode
    model.train()

    # Iterate over the training data
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Pass the input through the model
        outputs = model(images)

        # Compute the loss
        loss = loss_fn(outputs, labels)

        # Backpropagate the error
        loss.backward()

        # Update the model parameters
        optimizer.step()

    # Set the model to evaluation mode
    model.eval()

    # Evaluate the model on the validation set
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Pass the input through the model
            outputs = model(images)

            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述


http://www.niftyadmin.cn/n/5081592.html

相关文章

React 状态管理 - Context API 前世今生(上)旧版v16.3前

目录 扩展学习资料 Context api before React v16.3 Context 实战使用-Context Context VS Props Context Props Context的缺陷 New Context API 的实践 扩展学习资料 名称 链接 备注 new context api https://reactjs.org/docs/context.html 英文 old context …

class类实现Serializable接口生成serialVersionUID

前言 我在class类实现了Serializable接口,发现把鼠标放在这个类名上,然后键盘输入altenter键没有生成serialVersionUID的提示 解决 找到Editor下边的Inspections,然后搜索UID,把如下截图中的勾选即可 效果 鼠标光标放在类名上&am…

记录用命令行将项目打包成war包

记录用命令行将项目打包成war包 找到项目的pom.xml 在当前路径下进入cmd 输入命令 mvn clean package 发现报错了 Failed to execute goal org.apache.maven.plugins:maven-war-plugin:2.2:war (default-war) on project MMS: Error assembling WAR: webxml attribute is req…

使用wireshark解析ipsec esp包

Ipsec esp包就是ipsec通过ike协议协商好后建立的通信隧道使用的加密包,该加密包里面就是用户的数据,比如通过的语音等。 那么如何将抓出来的esp包解析出来看呢? 获取相关的esp的key信息. 打开wireshark -> edit->preferences 找到pr…

Zookeeper经典应用场景实战

1. Zookeeper Java客户端实战 ZooKeeper应用的开发主要通过Java客户端API去连接和操作ZooKeeper集群。可供选择的Java客户端API有: ZooKeeper官方的Java客户端API。第三方的Java客户端API,比如Curator。 ZooKeeper官方的客户端API提供了基本的操作。例…

工程物料管理信息化建设(十二)——关于工程物料管理系统最后的思考

目录 1 功能回顾1.1 MTO模块1.2 请购模块1.3 采购模块1.4 催交模块1.5 现场管理模块1.6 数据分析和看板模块1.7 其它模块 2 最后几个问题2.1 按管线发料和直接发料重叠2.2 YHA 材料编码的唯一性问题2.3 “合同量单-箱单-入库单” 数据映射 3 关于未来的思考3.1 三个专业之间的关…

VSCODE+PHP8.2配置踩坑记录

VSCODEPHP8.2配置踩坑记录 – WhiteNights Site 我配置过的最恶心的环境之一:windows上的php。另一个是我centos服务器上的php。 进不了断点 端口配置和xdebug的安装 这个应该是最常见的问题了。从网上下载完php并解压到本地,打开vscode,安装…

开发者职场“生存状态”大调研报告分析 - 第一版

听人劝、吃饱饭,奉劝各位小伙伴,不要订阅该文所属专栏。 作者:不渴望力量的哈士奇(哈哥),十余年工作经验, 跨域学习者,从事过全栈研发、产品经理等工作,现任研发部门 CTO 。荣誉:2022年度博客之星Top4、博客专家认证、全栈领域优质创作者、新星计划导师,“星荐官共赢计…