Rylah's Study & Daily Life
PyTorch 01. Basic PyTorch 본문
TorchVision, TorchSummary Install
View > Tool Windows > Terminal
pip install torchvision==0.1.8
pip install torchsummary
import torch
# 입력 뉴런 : 5개, 출력 뉴런 : 3개
# 총 사용 시냅스 15개
hello = torch.nn.Linear(5, 3)
# 임의의 입력 벡터 2개 만들기
data = torch.randn(2, 5)
# 인공 신명망에 입력하고 결과를 출력
print(data)
print(hello(data))
# 필요한 패키지 가져오기
from torchvision.datasets import MNIST # TorchVision에서 MNIST 함수를 가져옴 (배치 설정 등)
import torchvision.transforms as transforms # Torchvision에서 Transform 함수를 수입
from torch.utils.data import DataLoader # Dataloader, 학습용 데이터를 한줄씩 가져올 때 사용
import torch.nn as nn # 인공신경망을 만드는 층 NN 패키지
from torchsummary import summary
# 데이터 변환 방식 지정
rules = transforms.Compose([ # DNN이 PyTorch로 만들어지기 때문에 Pytorch Tensor로 전환
transforms.ToTensor(), # Numpy -> PyTorchTensor
])
# 학습용 데이터 로더
train_loader = DataLoader( # 6만개의 이미지, 한 배치 = 500개씩
MNIST('mnist', train=True, download=True, transform=rules),
batch_size=500, shuffle=True
)
# 평가용 데이터 로더
test_loader = DataLoader(
MNIST('mnist', train=False, download=True, transform=rules),
batch_size=500, shuffle=False
)
# 임의의 이미지 한 batch 가져오기
i = iter(train_loader)
images, labels = i.next()
print(images[0])
print(labels[0])
tensor([[-0.4553, -0.4855, -1.2358, -0.1158, -3.1778],
[-0.9592, 0.4361, 0.6890, 0.6946, 0.4702]])
tensor([[-1.0712, -0.2554, 0.2080],
[ 0.6585, -0.0785, 0.2994]], grad_fn=<AddmmBackward>)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.2314, 0.4706,
0.6431, 0.9961, 0.9961, 0.7020, 0.0902, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0196, 0.2588, 0.5765, 0.9255, 0.9961, 1.0000,
0.9961, 0.9961, 0.9961, 0.9961, 0.8157, 0.0706, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.1216, 0.7647, 0.9961, 0.9961, 0.9961, 0.9961, 0.8078,
0.8039, 0.6863, 0.8510, 0.9961, 0.9961, 0.1529, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2275,
0.5961, 0.9490, 0.9961, 0.9961, 0.9961, 0.6431, 0.1608, 0.0314,
0.0314, 0.0000, 0.1020, 0.8118, 0.7843, 0.0118, 0.1098, 0.2196,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.8078,
0.9961, 0.9961, 0.9961, 0.9647, 0.3451, 0.0196, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0902, 0.3255, 0.0627, 0.9255, 0.4588,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4039, 0.9961,
0.9961, 0.9098, 0.5333, 0.2078, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.2471, 0.9961, 0.0902, 0.0784, 0.9961, 0.4627,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8431, 0.9961,
0.5373, 0.0941, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.1725, 0.8980, 0.9961, 0.7098, 0.0784, 0.9412, 0.2745,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8431, 0.9961,
0.8314, 0.6902, 0.6039, 0.2863, 0.0000, 0.0000, 0.0000, 0.0745,
0.3412, 0.9020, 0.9961, 0.9961, 0.5020, 0.0549, 0.2118, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2706, 0.9490,
0.9961, 0.9961, 0.9961, 0.9882, 0.9216, 0.9216, 0.9216, 0.9412,
0.9961, 0.9961, 0.9961, 0.9333, 0.0314, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2784,
0.7569, 0.9059, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961,
0.9961, 0.9961, 0.9961, 0.4667, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1804, 0.3882, 0.7373, 0.6196, 0.9137,
0.9961, 0.9961, 0.7529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.7922,
0.9961, 0.8510, 0.0980, 0.0000, 0.0000, 0.0000, 0.0314, 0.0314,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6510, 0.9961,
0.9686, 0.4784, 0.0000, 0.0000, 0.0000, 0.0000, 0.5961, 0.4392,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1451, 0.9765, 0.9961,
0.7098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6902, 0.4588,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5961, 0.9961, 0.9961,
0.2392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2784, 0.2824,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9255, 0.9961, 0.7686,
0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.9843, 0.9961, 0.3882,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0275, 0.7843, 0.9961, 0.9961, 0.1176,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0784, 0.9961, 0.9961, 0.4353, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.0745,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.2275, 0.9961, 0.9961, 0.2275, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627, 0.9255, 0.4588,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]])
tensor(9)
Process finished with exit code 0
'Study > Deep Learning' 카테고리의 다른 글
PyTorch 03. DNN - Heart Disease Dataset (0) | 2022.02.11 |
---|---|
PyTorch 02. MNIST (0) | 2022.02.11 |
Pytorch 00. Hello Pytorch (0) | 2022.02.11 |
[Tensorflow] 04. RNN - AirLine (0) | 2022.01.26 |
[TensorFlow] 03. CNN - Fashion Items (0) | 2022.01.26 |