self_example/TensorFlow_eaxmple/Model_train_test/datadeal/saveData.py

84 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import tensorflow as tf
from datadeal import labeled_and_piece, loadData
import numpy as np
# 导入数据,并进行数据处理
# 导入第一类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱齿轮表面磨损故障恒速\DATA')
data1 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱齿轮表面磨损故障恒速\DATA'))
dataWithLabel1 = labeled_and_piece.GetLabel(data1, 0.000, False)
(data1, dataWithLabel1) = dataWithLabel1(data1, 0.000, False)
# 导入第二类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱齿轮齿根裂纹故障恒速\DATA')
data2 = np.array(data0(9, 'E:\DDS_data\平行齿轮箱齿轮齿根裂纹故障恒速\DATA'))
dataWithLabel2 = labeled_and_piece.GetLabel(data2, 1.000, False)
(data2, dataWithLabel2) = dataWithLabel2(data2, 1.000, False)
# 导入第三类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱齿轮断齿故障恒速\DATA')
data3 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱齿轮断齿故障恒速\DATA'))
dataWithLabel3 = labeled_and_piece.GetLabel(data3, 2.000, False)
(data3, dataWithLabel3) = dataWithLabel3(data3, 2.000, False)
# 导入第四类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱齿轮偏心故障恒速\DATA')
data4 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱齿轮偏心故障恒速\DATA'))
dataWithLabel4 = labeled_and_piece.GetLabel(data4, 3.000, False)
(data4, dataWithLabel4) = dataWithLabel4(data4, 3.000, False)
# 导入第五类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱齿轮缺齿故障恒速\DATA')
data5 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱齿轮缺齿故障恒速\DATA'))
dataWithLabel5 = labeled_and_piece.GetLabel(data5, 4.000, False)
(data5, dataWithLabel5) = dataWithLabel5(data5, 4.000, False)
# 导入第六类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱轴承复合故障恒速\DATA')
data6 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱轴承复合故障恒速\DATA'))
dataWithLabel6 = labeled_and_piece.GetLabel(data6, 5.000, False)
(data6, dataWithLabel6) = dataWithLabel6(data6, 5.000, False)
# 导入第七类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱轴承滚动体故障恒速\DATA')
data7 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱轴承滚动体故障恒速\DATA'))
dataWithLabel7 = labeled_and_piece.GetLabel(data7, 6.000, False)
(data7, dataWithLabel7) = dataWithLabel7(data7, 6.000, False)
# 导入第八类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱轴承内圈故障恒速\DATA')
data8 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱轴承内圈故障恒速\DATA'))
dataWithLabel8 = labeled_and_piece.GetLabel(data8, 7.000, False)
(data8, dataWithLabel8) = dataWithLabel8(data8, 7.000, False)
# 导入第九类故障并打标签
data0 = loadData.DataDeal4(9, 'E:\data\DDS_data\平行齿轮箱轴承外圈故障恒速\DATA')
data9 = np.array(data0(9, 'E:\data\DDS_data\平行齿轮箱轴承外圈故障恒速\DATA'))
dataWithLabel9 = labeled_and_piece.GetLabel(data9, 8.000, False)
(data9, dataWithLabel9) = dataWithLabel9(data9, 8.000, False)
# 合并
data_all = tf.concat([data1, data2, data3, data4, data5, data6, data7, data8, data9], axis=0)
label_all = tf.concat(
[dataWithLabel1, dataWithLabel2, dataWithLabel3, dataWithLabel4, dataWithLabel5, dataWithLabel6, dataWithLabel7,
dataWithLabel8, dataWithLabel9], axis=0)
# print("data_all",data_all)
# print("label_all",label_all)
# data_all = np.array(data_all)
# 划分训练集和测试集,并打乱
data_new = labeled_and_piece.PieceAndBag_new(data_all, label_all, False)
(train_data, train_label), (test_data, test_label) = data_new(data_all, label_all, False)
np.save("train_data.npy",train_data)
np.save("train_label.npy",train_label)
np.save("test_data.npy",test_data)
np.save("test_label.npy",test_label)
# confusion_matrix(test_label,y_pred=?) //混淆矩阵的得法
# TSNE.fit_transform() //T-SNE降维表示
'''train_data.shape: (7776, 80, 80, 9)
train_label.shape: (7776, 1)
test_data.shape: (2592, 80, 80, 9)
test_label.shape: (2592, 1)'''