# Source code for jaclearn.datasets.image_classification.cifar
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : cifar.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os.path as osp
import functools
import pickle
import tarfile
import numpy as np
from jacinle.io.network import download
__all__ = ['load_cifar', 'load_cifar10', 'load_cifar100']
cifar_web_address = 'http://www.cs.toronto.edu/~kriz/'
def _read_cifar(filenames, cls):
image = []
label = []
for fname in filenames:
with open(fname, 'rb') as f:
raw_dict = pickle.load(f, encoding='latin1')
raw_data = raw_dict['data']
label.extend(raw_dict['labels' if cls == 10 else 'fine_labels'])
for x in raw_data:
x = x.reshape(3, 32, 32)
x = np.transpose(x, [1, 2, 0])
image.append(x)
return np.array(image), np.array(label)
# [docs]
def load_cifar(data_dir, nr_classes=10):
    """Load CIFAR-10 or CIFAR-100, downloading and extracting it if needed.

    Args:
        data_dir: directory that holds (or will receive) the ``.tar.gz``
            archive and its extracted folder.
        nr_classes: 10 for CIFAR-10 or 100 for CIFAR-100.

    Returns:
        ``(train_set, test_set)``, each an ``(images, labels)`` pair as
        produced by ``_read_cifar``.

    Raises:
        ValueError: if ``nr_classes`` is neither 10 nor 100.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if nr_classes not in (10, 100):
        raise ValueError('nr_classes must be 10 or 100, got {!r}.'.format(nr_classes))

    data_file = 'cifar-{}-python.tar.gz'.format(nr_classes)
    origin = cifar_web_address + data_file
    dataset = osp.join(data_dir, data_file)

    if nr_classes == 10:
        folder_name = 'cifar-10-batches-py'
        filenames = ['data_batch_{}'.format(i) for i in range(1, 6)]
        filenames.append('test_batch')
    else:
        folder_name = 'cifar-100-python'
        filenames = ['train', 'test']

    if not osp.isdir(osp.join(data_dir, folder_name)):
        if not osp.isfile(dataset):
            download(origin, data_dir, data_file)
        # Close the archive deterministically once extraction finishes.
        with tarfile.open(dataset, 'r:gz') as tar:
            # Where supported, the 'data' filter rejects members that would be
            # written outside data_dir (absolute paths, '..', unsafe links).
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(data_dir, filter='data')
            else:
                tar.extractall(data_dir)

    filenames = [osp.join(data_dir, folder_name, name) for name in filenames]
    # The last file is the test split; all preceding files form the training split.
    train_set = _read_cifar(filenames[:-1], nr_classes)
    test_set = _read_cifar(filenames[-1:], nr_classes)
    return train_set, test_set
# Convenience loaders: ``load_cifar`` with ``nr_classes`` pre-bound to 10 and
# 100 respectively. Being ``functools.partial`` objects, callers may still
# override ``nr_classes`` by keyword.
load_cifar10, load_cifar100 = (
    functools.partial(load_cifar, nr_classes=count) for count in (10, 100)
)