[PyTorch] 4. Dataset & DataLoader
피라_노트
2022. 1. 28. 23:44
In [1]:
from IPython.display import Image
import numpy as np
import torch
from torch import nn
from torch import Tensor
In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    # Define how the initial data is set up
    def __init__(self, text, labels):
        self.labels = labels
        self.data = text

    # Total length of the dataset
    def __len__(self):
        return len(self.labels)

    # The form of the data returned for a given index: (X, y)
    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.data[idx]
        sample = {"Text": text, "Class": label}
        return sample
DataLoader class
- A class that creates batches from the data.
- A Dataset defines how to fetch a single sample; a DataLoader splits and groups those samples by batch size.
- Responsible for converting the data (to tensors) right before training, i.e. before feeding the GPU.
- Its main jobs are tensor conversion plus batch handling.
- It is also where parallel data preprocessing (overlapping CPU work with GPU work) needs to be considered; see the sketch below.
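For reference, a minimal sketch of the parallel-loading knobs (num_workers, pin_memory, and drop_last are standard DataLoader parameters; the dataset name here is just a placeholder):
In [ ]:
# Illustrative only: overlap CPU-side loading with GPU-side training.
loader = DataLoader(
    some_dataset,     # placeholder: any Dataset instance
    batch_size=64,
    shuffle=True,
    num_workers=4,    # worker processes preprocess batches in parallel on the CPU
    pin_memory=True,  # page-locked memory speeds up host-to-GPU copies
    drop_last=True,   # drop the final incomplete batch
)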
In [3]:
# Create the dataset
text = ['Happy', 'Amazing', 'Sad', 'Unhappy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
MyDataset = CustomDataset(text, labels)
In [4]:
type(MyDataset)
Out[4]:
__main__.CustomDataset
In [5]:
MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
next(iter(MyDataLoader))
Out[5]:
{'Text': ['Glum', 'Happy'], 'Class': ['Negative', 'Positive']}
In [7]:
next(iter(MyDataLoader))
Out[7]:
{'Text': ['Unhappy', 'Happy'], 'Class': ['Negative', 'Positive']}
In [16]:
for dataset in MyDataLoader:
    # with 5 samples and batch_size=2, the batches come out as 2, 2, 1
    print(dataset)
{'Text': ['Amazing', 'Happy'], 'Class': ['Positive', 'Positive']}
{'Text': ['Sad', 'Glum'], 'Class': ['Negative', 'Negative']}
{'Text': ['Unhappy'], 'Class': ['Negative']}
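Note that the batches above are plain Python lists because each sample is a string; when __getitem__ returns tensors, the default collate function stacks them into one batched tensor. A small sketch (the toy tensors are illustrative):
In [ ]:
import torch
from torch.utils.data import TensorDataset, DataLoader

xs = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 5 samples, 2 features
ys = torch.tensor([0, 0, 1, 1, 1])
loader = DataLoader(TensorDataset(xs, ys), batch_size=2)
x_batch, y_batch = next(iter(loader))
print(x_batch.shape)  # torch.Size([2, 2]): default collate stacked the tensors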
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
DataLoader reference: https://subinium.github.io/pytorch-dataloader/
collate_fn: often used to handle variable-length samples (e.g., padding sequences within a batch); a sketch follows below.
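For example, a minimal sketch of a custom collate_fn that pads variable-length sequences (pad_sequence is a real torch utility; the toy data is illustrative):
In [ ]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (sequence_tensor, label) pairs of differing lengths
    seqs, labels = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True)  # pad to the longest in the batch
    return padded, torch.tensor(labels)

toy = [(torch.tensor([1, 2, 3]), 0), (torch.tensor([4, 5]), 1)]
loader = DataLoader(toy, batch_size=2, collate_fn=pad_collate)
x, y = next(iter(loader))  # x has shape (2, 3), padded with zeros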
Training on the NotMNIST data
- including automating the NotMNIST download
In [2]:
from torchvision.datasets import VisionDataset
from typing import Any, Callable, Dict, List, Optional, Tuple
import os
import sys
from tqdm import tqdm
from pathlib import Path
import requests
from skimage import io, transform
import matplotlib.pyplot as plt
In [3]:
import tarfile

class NotMNIST(VisionDataset):
    resource_url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,  # let the user decide whether to download
    ) -> None:
        super(NotMNIST, self).__init__(root, transform=transform,
                                       target_transform=target_transform)

        if not self._check_exists() or download:  # download if the data is missing or explicitly requested
            self.download()

        self.data, self.targets = self._load_data()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_name = self.data[index]
        image = io.imread(image_name)
        label = self.targets[index]
        if self.transform:  # read the image and pass it through the transform
            image = self.transform(image)
        return image, label

    def _load_data(self):
        filepath = self.image_folder
        data = []
        targets = []
        for target in os.listdir(filepath):  # build the file list from this path
            filenames = [os.path.abspath(
                os.path.join(filepath, target, x)) for x in os.listdir(
                    os.path.join(filepath, target))]
            targets.extend([target] * len(filenames))
            data.extend(filenames)
        return data, targets

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def image_folder(self) -> str:
        return os.path.join(self.root, 'notMNIST_large')

    def download(self) -> None:
        os.makedirs(self.raw_folder, exist_ok=True)
        fname = self.resource_url.split("/")[-1]
        chunk_size = 1024

        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                          "AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/80.0.3987.122 Safari/537.36"
        }

        r = requests.get(self.resource_url, headers=headers)
        filesize = int(r.headers["Content-Length"])

        with requests.get(self.resource_url, stream=True, headers=headers) as r, open(
                os.path.join(self.raw_folder, fname), "wb") as f, tqdm(
                unit="B",           # unit string to be displayed
                unit_scale=True,    # let tqdm determine the scale (kilo, mega, ...)
                unit_divisor=1024,  # used when unit_scale is true
                total=filesize,     # the total iteration
                file=sys.stdout,    # default goes to stderr; display on the console instead
                desc=fname          # prefix displayed on the progress bar
        ) as progress:
            for chunk in r.iter_content(chunk_size=chunk_size):
                # download the file chunk by chunk
                datasize = f.write(chunk)
                # on each chunk, update the progress bar
                progress.update(datasize)

        # extract the archive
        self._extract_file(os.path.join(self.raw_folder, fname), target_path=self.root)

    def _extract_file(self, fname, target_path) -> None:
        if fname.endswith("tar.gz"):
            tag = "r:gz"
        elif fname.endswith("tar"):
            tag = "r:"
        tar = tarfile.open(fname, tag)
        tar.extractall(path=target_path)
        tar.close()

    def _check_exists(self) -> bool:
        return os.path.exists(self.raw_folder)
In [ ]:
dataset = NotMNIST("data", download=True)
In [ ]:
fig = plt.figure()

for i in range(8):
    sample = dataset[i]

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    plt.imshow(sample[0])

    if i == 3:
        plt.show()
        break
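As a follow-up, a minimal, untested sketch of feeding NotMNIST into a DataLoader for training. Note that __getitem__ above applies only self.transform, so the string labels ('A'..'J', the notMNIST class folders) pass through unchanged; mapping them to integers is done here in a custom collate_fn (an illustrative choice, not from the original post):
In [ ]:
import torch
from torchvision import transforms

def notmnist_collate(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)                                # (B, 1, H, W)
    labels = torch.tensor([ord(l) - ord('A') for l in labels])  # 'A'..'J' -> 0..9
    return images, labels

dataset = NotMNIST("data", transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=128, shuffle=True,
                    collate_fn=notmnist_collate)
images, labels = next(iter(loader))  # ready to feed into a model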