数据集为usps手写数据集(.mat形式),共9298张图片,维度16*16,内附有python版的使用代码
代码片段和文件信息
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2020-05-25 10:11 usps手写数据集+使用代码
文件 1759 2020-05-25 10:14 usps手写数据集+使用代码 est_usps.py
文件 19228688 2006-03-13 20:48 usps手写数据集+使用代码usps_resampled.mat
import numpy as np #用于数据处理
import matplotlib.pyplot as plt # 用于展示图片
import scipy.io as sio # 用于读取.mat
def load_dataset(dataset=‘usps‘):
# 加载usps数据集
if dataset == ‘usps‘:
data = sio.loadmat(‘usps_resampled.mat‘)
x_train y_train x_test y_test = data[‘train_patterns‘].T data[‘train_labels‘].T data[‘test_patterns‘].T data[‘test_labels‘].T
x = np.concatenate((x_train x_test))
y_train = [np.argmax(l) for l in y_train] # 将onehot编码转成一般编码
y_test = [np.argmax(l) for l in y_test] # 将onehot编码转成一般编码
y = np.concatenate((np.array(y_train) np.array(y_test))).astype(np.int32)
x = x.reshape((-1 16 16 1)).astype(np.float32) # 便于使用卷积层
# x = x.reshape((x.shape[0] 16*16)).astype(np.float32) # 便于使用全连接层
x = np.divide(x 255.) # 归一化
print(‘USPS:‘ x.shape y.shape) # (9298 16 16 1)
return x y
else:
print(‘The dataSet name is useless‘)
exit(0)
def show_figure(data): # 显示前200张图片
digit_size = data.shape[1] # 16 或者 28
data = np.squeeze(data) # 去掉1维
figure = np.zeros((digit_size * 10 digit_size * 20))
t = 0
for i in range(10): # 10行
for j in range(20): # 每行展示20个数据
figure[i * digit_size: (i+1) * digit_size j * digit_size: (j+1) * digit_size] = data[t]
t = t + 1
plt.figure(figsize=(15 15))
plt.imshow(figure)
plt.show()
if __name__ == ‘__main__‘:
# load dataset
x y = load_dataset(‘usps‘)
print(y[:200]) # 展示前200个样本的标签
show_figure(x) # 展示前200个样本数据
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2020-05-25 10:11 usps手写数据集+使用代码
文件 1759 2020-05-25 10:14 usps手写数据集+使用代码 est_usps.py
文件 19228688 2006-03-13 20:48 usps手写数据集+使用代码usps_resampled.mat
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件举报,一经查实,本站将立刻删除。
评论列表(条)