Source code for jaclearn.datasets.image_classification.mnist
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : mnist.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os.path as osp
import gzip
import pickle
from jacinle.io.network import download
# Public API of this module: a star-import exposes only the MNIST loader.
__all__ = [
    'load_mnist',
]
[docs]
def load_mnist(
        data_dir,
        data_file='mnist.pkl.gz',
        origin='http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'):
    """Load the gzip-compressed, pickled MNIST dataset from ``data_dir``.

    If the file is missing and ``data_file`` is the default name
    ``'mnist.pkl.gz'``, it is first fetched with
    ``download(origin, data_dir, data_file)``. For any other ``data_file``,
    no download is attempted and the file must already exist.

    Args:
        data_dir: directory containing (or to receive) the dataset file.
        data_file: file name inside ``data_dir``; must be a gzip-compressed
            pickle holding exactly three objects.
        origin: URL handed to ``download`` when the default file is missing.

    Returns:
        ``(train_set, valid_set, test_set)``: the three objects stored in the
        pickle, in order. Presumably each is an ``(images, labels)`` pair;
        TODO confirm against the file being loaded.

    Raises:
        FileNotFoundError: if the dataset file does not exist and was not
            downloaded (e.g. a non-default ``data_file``).
        ValueError: if the unpickled object does not unpack into exactly
            three items.
    """
    dataset = osp.join(data_dir, data_file)
    # Auto-download only for the default file name.
    if not osp.isfile(dataset) and data_file == 'mnist.pkl.gz':
        download(origin, data_dir, data_file)

    # NOTE(security): unpickling can execute arbitrary code, and this file may
    # just have been fetched over plain HTTP. Only load files from a trusted
    # source.
    with gzip.open(dataset, 'rb') as f:
        try:
            # ``encoding='latin1'`` lets Python 3 read pickles produced by
            # Python 2 (needed e.g. for pickled NumPy arrays).
            data = pickle.load(f, encoding='latin1')
        except TypeError:
            # Python 2's ``pickle.load`` does not accept ``encoding``. Rewind
            # in case the failed attempt already consumed part of the stream.
            f.seek(0)
            data = pickle.load(f)

    # Unpack outside the ``try`` so a malformed pickle surfaces as the real
    # error instead of triggering a bogus re-read of the stream.
    train_set, valid_set, test_set = data
    return train_set, valid_set, test_set