58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
# -*- encoding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/12/12 12:53
@Usage :
@Desc :
'''
# -*- coding: UTF-8 -*-
|
|
import os
|
|
from torch.utils.data import DataLoader, Dataset
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import encoding as ohe
|
|
import setting
|
|
|
|
|
|
class mydataset(Dataset):
    """Image dataset whose labels are parsed from the image file names.

    Every entry found directly inside ``folder`` is treated as one sample.
    The part of the file name before the first underscore is handed to
    ``ohe.encode`` to produce the label.
    """

    def __init__(self, folder, transform=None):
        """Collect the sample paths found directly inside ``folder``.

        Args:
            folder: Directory whose entries are the samples.
            transform: Optional callable applied to each opened PIL image.
        """
        # os.listdir() returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run and on every platform.
        self.train_image_file_paths = [
            os.path.join(folder, image_file)
            for image_file in sorted(os.listdir(folder))
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.train_image_file_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Position of the sample in ``train_image_file_paths``.

        Returns:
            A tuple ``(image, label)``: the (optionally transformed) image and
            the result of ``ohe.encode`` on the file-name prefix.
        """
        image_path = self.train_image_file_paths[idx]
        # basename() honours every separator the platform accepts, whereas
        # splitting on os.path.sep alone fails for paths such as "dir/x.png"
        # on Windows.
        image_name = os.path.basename(image_path)
        # Image.open() is lazy and keeps the file handle open; force the
        # decode inside the context manager so the handle is released even
        # when no transform is applied. The loaded pixels remain usable.
        with Image.open(image_path) as image:
            image.load()
        if self.transform is not None:
            image = self.transform(image)
        label = ohe.encode(image_name.split('_')[0])
        return image, label
|
|
|
|
|
|
# Preprocessing shared by all data loaders in this module: convert the image
# to grayscale, then to a tensor. Colour jitter and mean/std normalisation are
# intentionally not part of the pipeline.
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
|
|
|
|
|
|
def get_train_data_loader(batch_size=64, shuffle=True):
    """Build the DataLoader over the training images.

    Args:
        batch_size: Number of samples per batch. Defaults to 64, the value
            that used to be hard-coded.
        shuffle: Passed straight to ``DataLoader``; defaults to True as
            before.

    Returns:
        A ``DataLoader`` over ``mydataset(setting.TRAIN_DATASET_PATH)`` that
        applies the module-level ``transform`` to every image.
    """
    dataset = mydataset(setting.TRAIN_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
|
|
def get_eval_data_loader(batch_size=1, shuffle=True):
    """Build the DataLoader over the evaluation images.

    Args:
        batch_size: Number of samples per batch. Defaults to 1, the value
            that used to be hard-coded.
        shuffle: Passed straight to ``DataLoader``; defaults to True as
            before.

    Returns:
        A ``DataLoader`` over ``mydataset(setting.EVAL_DATASET_PATH)`` that
        applies the module-level ``transform`` to every image.
    """
    dataset = mydataset(setting.EVAL_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
|
|
def get_predict_data_loader(batch_size=1, shuffle=True):
    """Build the DataLoader over the images to run prediction on.

    Args:
        batch_size: Number of samples per batch. Defaults to 1, the value
            that used to be hard-coded.
        shuffle: Passed straight to ``DataLoader``; defaults to True as
            before. Pass False to iterate the files in dataset order.

    Returns:
        A ``DataLoader`` over ``mydataset(setting.PREDICT_DATASET_PATH)`` that
        applies the module-level ``transform`` to every image.
    """
    dataset = mydataset(setting.PREDICT_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)