usps手写数据集+使用代码.zip


数据集为usps手写数据集(.mat形式),共9298张图片,维度16*16,内附有python版的使用代码
资源截图
代码片段和文件信息
import numpy as np  #用于数据处理
import matplotlib.pyplot as plt  # 用于展示图片
import scipy.io as sio  # 用于读取.mat

def load_dataset(dataset=‘usps‘):
    # 加载usps数据集
    if dataset == ‘usps‘:
        data = sio.loadmat(‘usps_resampled.mat‘)
        x_train y_train x_test y_test = data[‘train_patterns‘].T data[‘train_labels‘].T data[‘test_patterns‘].T data[‘test_labels‘].T
        x = np.concatenate((x_train x_test))
        y_train = [np.argmax(l) for l in y_train]  # 将onehot编码转成一般编码
        y_test = [np.argmax(l) for l in y_test]  # 将onehot编码转成一般编码
        y = np.concatenate((np.array(y_train) np.array(y_test))).astype(np.int32)
        x = x.reshape((-1 16 16 1)).astype(np.float32)   # 便于使用卷积层
        # x = x.reshape((x.shape[0] 16*16)).astype(np.float32)   # 便于使用全连接层
        x = np.divide(x 255.)  # 归一化
        print(‘USPS:‘ x.shape y.shape)  # (9298 16 16 1)
        return x y

    else:
        print(‘The dataSet name is useless‘)
        exit(0)


def show_figure(data):  # 显示前200张图片
    digit_size = data.shape[1]  # 16 或者 28
    data = np.squeeze(data)  # 去掉1维
    figure = np.zeros((digit_size * 10 digit_size * 20))
    t = 0
    for i in range(10):  # 10行
        for j in range(20):  # 每行展示20个数据
            figure[i * digit_size: (i+1) * digit_size j * digit_size: (j+1) * digit_size] = data[t]
            t = t + 1
    plt.figure(figsize=(15 15))
    plt.imshow(figure)
    plt.show()

if __name__ == ‘__main__‘:
    # load dataset
    x y = load_dataset(‘usps‘)
    print(y[:200])  # 展示前200个样本的标签
    show_figure(x)  # 展示前200个样本数据

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2020-05-25 10:11  usps手写数据集+使用代码
     文件        1759  2020-05-25 10:14  usps手写数据集+使用代码 est_usps.py
     文件    19228688  2006-03-13 20:48  usps手写数据集+使用代码usps_resampled.mat

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件举报,一经查实,本站将立刻删除。

发表评论

评论列表(条)