Mạng nơ ron là một mô hình tính toán phổ biến trong deep learning trong bài viết này chúng ta sẽ xây dựng một mạng đơn giản sử dụng module torch.nn
của thư viện PyTorch
Cài đặt pytorch
pip install torch
Các bước
- Nhập tất cả các thư viện cần thiết
- Xác định và khởi tạo mạng nơ-ron
- Đưa dữ liệu vào mô hình để kiểm tra
nhập thư viện
import torch
import torch.nn as nn
import torch.nn.functional as F
Định nghĩa kiến trúc mạng nơ ron
Mạng neural network đơn giản
Với input là một vector features có độ dài 12 và output là vector features có độ dài là 2
class ANNNet(nn.Module):
def __init__(self, input_features = 12, hidden1 = 20, hidden2 = 20, out_features = 2):
super().__init__()
self.f_connected1 = nn.Linear(input_features, hidden1)
self.f_connected2 = nn.Linear(hidden1, hidden2)
self.out = nn.Linear(hidden2, out_features)
def forward(self, x):
x = F.relu(self.f_connected1(x))
x = F.relu(self.f_connected2(x))
x = F.softmax(self.out(x), dim=1)
return x
Ví dụ về mạng CNN trong các tác vụ xử lý hình ảnh
Mạng này input là một ảnh
class CNNNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
# x represents our data
def forward(self, x):
# Pass data through conv1
x = self.conv1(x)
# Use the rectified-linear activation function over x
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
# Run max pooling over x
x = F.max_pool2d(x, 2)
# Pass data through dropout1
x = self.dropout1(x)
# Flatten x with start_dim=1
x = torch.flatten(x, 1)
# Pass data through fc1
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
# Apply softmax to x
output = F.log_softmax(x, dim=1)
return output
Đưa dữ liệu qua mô hình để kiểm tra
# tạo một cái ảnh ngẫu nhiên
random_data = torch.rand((1, 1, 28, 28))
my_nn = CNNNet()
result = my_nn(random_data)
print (result)
Kết quả ngẫu nhiên
tensor([[-2.3025, -2.3728, -2.2621, -2.1108, -2.3175, -2.3395, -2.3442, -2.3291,
-2.2963, -2.3786]], grad_fn=<LogSoftmaxBackward0>)
Tương tự với mạng ANN
random_data = torch.rand((1, 12))
my_nn = ANNNet()
result = my_nn(random_data)
print (result)
Kết quả của mạng ANN
tensor([[0.4591, 0.5409]], grad_fn=<SoftmaxBackward0>)