Rylah's Study & Daily Life

PyTorch 01. Basic PyTorch 본문

Study/Deep Learning

PyTorch 01. Basic PyTorch

Rylah 2022. 2. 11. 10:09

TorchVision, TorchSummary Install

View > Tool Windows > Terminal

 

pip install torchvision==0.1.8

 pip install torchsummary

 

import torch

# 입력 뉴런 : 5개, 출력 뉴런 : 3개
# 총 사용 시냅스 15개

hello = torch.nn.Linear(5, 3)

# 임의의 입력 벡터 2개 만들기
data = torch.randn(2, 5)

# 인공 신명망에 입력하고 결과를 출력
print(data)
print(hello(data))

# 필요한 패키지 가져오기
from torchvision.datasets import MNIST # TorchVision에서 MNIST 함수를 가져옴 (배치 설정 등)
import torchvision.transforms as transforms # Torchvision에서 Transform 함수를 수입
from torch.utils.data import DataLoader # Dataloader, 학습용 데이터를 한줄씩 가져올 때 사용
import torch.nn as nn # 인공신경망을 만드는 층 NN 패키지
from torchsummary import summary

# 데이터 변환 방식 지정
rules = transforms.Compose([ # DNN이 PyTorch로 만들어지기 때문에 Pytorch Tensor로 전환
    transforms.ToTensor(), # Numpy -> PyTorchTensor
])

# 학습용 데이터 로더
train_loader = DataLoader( # 6만개의 이미지, 한 배치 = 500개씩
    MNIST('mnist', train=True, download=True, transform=rules),
    batch_size=500, shuffle=True
)

# 평가용 데이터 로더
test_loader = DataLoader(
    MNIST('mnist', train=False, download=True, transform=rules),
    batch_size=500, shuffle=False
)


# 임의의 이미지 한 batch 가져오기
i = iter(train_loader)
images, labels = i.next()

print(images[0])
print(labels[0])

tensor([[-0.4553, -0.4855, -1.2358, -0.1158, -3.1778],
        [-0.9592,  0.4361,  0.6890,  0.6946,  0.4702]])
tensor([[-1.0712, -0.2554,  0.2080],
        [ 0.6585, -0.0785,  0.2994]], grad_fn=<AddmmBackward>)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.2314, 0.4706,
          0.6431, 0.9961, 0.9961, 0.7020, 0.0902, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0196, 0.2588, 0.5765, 0.9255, 0.9961, 1.0000,
          0.9961, 0.9961, 0.9961, 0.9961, 0.8157, 0.0706, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1216, 0.7647, 0.9961, 0.9961, 0.9961, 0.9961, 0.8078,
          0.8039, 0.6863, 0.8510, 0.9961, 0.9961, 0.1529, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2275,
          0.5961, 0.9490, 0.9961, 0.9961, 0.9961, 0.6431, 0.1608, 0.0314,
          0.0314, 0.0000, 0.1020, 0.8118, 0.7843, 0.0118, 0.1098, 0.2196,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.8078,
          0.9961, 0.9961, 0.9961, 0.9647, 0.3451, 0.0196, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0902, 0.3255, 0.0627, 0.9255, 0.4588,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4039, 0.9961,
          0.9961, 0.9098, 0.5333, 0.2078, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.2471, 0.9961, 0.0902, 0.0784, 0.9961, 0.4627,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8431, 0.9961,
          0.5373, 0.0941, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1725, 0.8980, 0.9961, 0.7098, 0.0784, 0.9412, 0.2745,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8431, 0.9961,
          0.8314, 0.6902, 0.6039, 0.2863, 0.0000, 0.0000, 0.0000, 0.0745,
          0.3412, 0.9020, 0.9961, 0.9961, 0.5020, 0.0549, 0.2118, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2706, 0.9490,
          0.9961, 0.9961, 0.9961, 0.9882, 0.9216, 0.9216, 0.9216, 0.9412,
          0.9961, 0.9961, 0.9961, 0.9333, 0.0314, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2784,
          0.7569, 0.9059, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961,
          0.9961, 0.9961, 0.9961, 0.4667, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.1804, 0.3882, 0.7373, 0.6196, 0.9137,
          0.9961, 0.9961, 0.7529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.7922,
          0.9961, 0.8510, 0.0980, 0.0000, 0.0000, 0.0000, 0.0314, 0.0314,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6510, 0.9961,
          0.9686, 0.4784, 0.0000, 0.0000, 0.0000, 0.0000, 0.5961, 0.4392,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1451, 0.9765, 0.9961,
          0.7098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6902, 0.4588,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5961, 0.9961, 0.9961,
          0.2392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2784, 0.2824,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9255, 0.9961, 0.7686,
          0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.9843, 0.9961, 0.3882,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0275, 0.7843, 0.9961, 0.9961, 0.1176,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0784, 0.9961, 0.9961, 0.4353, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.0745,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.2275, 0.9961, 0.9961, 0.2275, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627, 0.9255, 0.4588,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]])
tensor(9)

Process finished with exit code 0

'Study > Deep Learning' 카테고리의 다른 글

PyTorch 03. DNN - Heart Disease Dataset  (0) 2022.02.11
PyTorch 02. MNIST  (0) 2022.02.11
Pytorch 00. Hello Pytorch  (0) 2022.02.11
[Tensorflow] 04. RNN - AirLine  (0) 2022.01.26
[TensorFlow] 03. CNN - Fashion Items  (0) 2022.01.26