# Source code for easyfl.datasets.cifar10.cifar10
import logging
import os
import torchvision
from easyfl.datasets.simulation import data_simulation
from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR10
from easyfl.datasets.utils.util import save_dict
# Module-level logger named after this module, used for progress messages below.
logger = logging.getLogger(name=__name__)
class Cifar10(BaseDataset):
    """CIFAR-10 dataset that is partitioned across simulated federated clients.

    Raw data is fetched with ``torchvision.datasets.CIFAR10``. The training set is
    split per client by ``data_simulation``. The simulated training split and the
    untouched test set are written to ``self.data_folder`` with ``save_dict``.

    Args:
        root: Root directory, forwarded to ``BaseDataset``.
        fraction: Forwarded to ``BaseDataset``.
        split_type: Partition strategy. Forwarded to ``BaseDataset`` and passed to
            ``data_simulation``.
        user: Forwarded to ``BaseDataset``.
        iid_user_fraction: Forwarded to ``BaseDataset``.
        train_test_split: Forwarded to ``BaseDataset``.
        minsample: Forwarded to ``BaseDataset``. Also stored as ``self.min_size``
            and passed to ``data_simulation``.
        num_class: Forwarded to ``BaseDataset``. NOTE(review): the default of 80
            does not match CIFAR-10's 10 classes; confirm whether it is relied on.
        num_of_client: Number of simulated clients, passed to ``data_simulation``.
        class_per_client: Passed to ``data_simulation``.
        setting_folder: Forwarded to ``BaseDataset``.
        seed: Forwarded to ``BaseDataset``.
        weights: Passed to ``data_simulation``. When not ``None``, ``preprocess``
            regenerates the split even if saved files already exist.
        alpha: Passed to ``data_simulation`` (presumably a Dirichlet
            concentration for non-IID splits; TODO confirm).
    """

    def __init__(self,
                 root,
                 fraction,
                 split_type,
                 user,
                 iid_user_fraction=0.1,
                 train_test_split=0.9,
                 minsample=10,
                 num_class=80,
                 num_of_client=100,
                 class_per_client=2,
                 setting_folder=None,
                 seed=-1,
                 weights=None,
                 alpha=0.5):
        super(Cifar10, self).__init__(root,
                                      CIFAR10,
                                      fraction,
                                      split_type,
                                      user,
                                      iid_user_fraction,
                                      train_test_split,
                                      minsample,
                                      num_class,
                                      num_of_client,
                                      class_per_client,
                                      setting_folder,
                                      seed)
        # Filled with {'x': ..., 'y': ...} by download_raw_file_and_extract.
        self.train_data, self.test_data = {}, {}
        self.split_type = split_type
        self.num_of_client = num_of_client
        self.weights = weights
        self.alpha = alpha
        self.min_size = minsample
        self.class_per_client = class_per_client

    def download_packaged_dataset_and_extract(self, filename):
        """No-op: CIFAR-10 is obtained in ``download_raw_file_and_extract`` instead."""

    def download_raw_file_and_extract(self):
        """Download CIFAR-10 via torchvision and keep the raw arrays in memory.

        Populates ``self.train_data`` and ``self.test_data`` as dicts with keys
        ``'x'`` (the dataset's ``.data``) and ``'y'`` (the dataset's ``.targets``).
        """
        train_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=True, download=True)
        test_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=False, download=True)
        self.train_data = {
            'x': train_set.data,
            'y': train_set.targets
        }
        self.test_data = {
            'x': test_set.data,
            'y': test_set.targets
        }

    def preprocess(self):
        """Simulate the per-client training split and save train/test data.

        Writes ``train`` (the simulated split) and ``test`` (the raw test set)
        into ``self.data_folder``. Does nothing when no custom ``weights`` are set
        and both files already exist.
        """
        train_data_path = os.path.join(self.data_folder, "train")
        test_data_path = os.path.join(self.data_folder, "test")
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.data_folder, exist_ok=True)
        # Reuse saved output only when BOTH files exist. Checking "train" alone
        # would make a missing "test" file (e.g. after a partial run) permanent.
        already_saved = os.path.exists(train_data_path) and os.path.exists(test_data_path)
        if self.weights is None and already_saved:
            return
        if not self.train_data or not self.test_data:
            # The raw arrays are only populated by download_raw_file_and_extract.
            # Load them here instead of failing below with a bare KeyError.
            self.download_raw_file_and_extract()
        logger.info("Start CIFAR10 data simulation")
        _, train_data = data_simulation(self.train_data['x'],
                                        self.train_data['y'],
                                        self.num_of_client,
                                        self.split_type,
                                        self.weights,
                                        self.alpha,
                                        self.min_size,
                                        self.class_per_client)
        logger.info("Complete CIFAR10 data simulation")
        save_dict(train_data, train_data_path)
        save_dict(self.test_data, test_data_path)

    def convert_data_to_json(self):
        """No-op: ``preprocess`` already saves the data with ``save_dict``."""