16 lines
951 B
Python
16 lines
951 B
Python
import tensorflow_datasets as tfds
|
||
splits=tfds.Split.TRAIN.subsplit(weighted=[2,1,1])
|
||
# tfds.Split.TRAIN.subsplit函数按传入的权重将其分成为训练集占50%,验证集占25%,测试集占25%
|
||
# metadata属性用于获取MNIST数据集的基本信息,包括数据的种类,大小以及对应的形式
|
||
(raw_trian,raw_validation,raw_test),metadata=tfds.load('mnist',split=list(splits),with_info=True,as_supervised=True)
|
||
import tensorflow as tf
|
||
# load函数中添加split参数,表示将数据在传入的时候直接进行分割,按数据的类型分割成“image”和“model”
|
||
# batch_size=-1,可以从返回的tf.Tensor对象获取NumPy数组中完整数据集
|
||
minst_data=tfds.load('mnist',split=tfds.Split.TRAIN,batch_size=-1)
|
||
numpy_ds=tfds.as_numpy(minst_data)
|
||
numpy_image,numpy_label=numpy_ds["image"],numpy_ds["label"]
|
||
'''minst_train,minst_test=minst_data['train'],minst_data['test']
|
||
|
||
print(minst_train)
|
||
print(minst_test'''
|