20231108爬虫学习更新
This commit is contained in:
parent
c1c870c74e
commit
df0f410f8f
|
|
@ -0,0 +1,164 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 19:09
|
||||
@Usage :
|
||||
@Desc : 一个基本的练习:爬取 https://ssr1.scrape.center 电影描述以及detail等
|
||||
'''
|
||||
|
||||
import requests
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urljoin
|
||||
from os import makedirs
|
||||
from os.path import exists
|
||||
from pyquery import PyQuery as pq
|
||||
import json
|
||||
|
||||
# 输出的日志级别
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s: %(message)s')
|
||||
BASE_URL = 'https://ssr1.scrape.center'
|
||||
TOTAL_PAGE = 10
|
||||
|
||||
RESULTS_DIR = 'results'
|
||||
exists(RESULTS_DIR) or makedirs(RESULTS_DIR)
|
||||
|
||||
|
||||
def scrape_page(url):
|
||||
"""
|
||||
scrape page by url and return its html
|
||||
:param url: page url
|
||||
:return: html of page
|
||||
"""
|
||||
logging.info('scraping %s...', url)
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return response.text
|
||||
logging.error('get invalid status code %s while scraping %s', response.status_code, url)
|
||||
except requests.RequestException:
|
||||
logging.error('error occurred while scraping %s', url, exc_info=True)
|
||||
|
||||
|
||||
def scrape_index(page):
|
||||
index_url = f'{BASE_URL}/page/{page}'
|
||||
return scrape_page(index_url)
|
||||
|
||||
|
||||
# 从页面HTML中获取详情页坐标
|
||||
def parse_index(html):
|
||||
pattern = '<a.*?href="(.*?)".*?class="name">'
|
||||
items = re.findall(pattern, html, re.S)
|
||||
if not items:
|
||||
return []
|
||||
for item in items:
|
||||
detail_url = urljoin(BASE_URL, item)
|
||||
logging.info('get detail url %s', detail_url)
|
||||
yield detail_url
|
||||
|
||||
|
||||
def scrape_detail(url):
|
||||
return scrape_page(url)
|
||||
|
||||
|
||||
def parse_detail(html):
|
||||
cover_pattern = re.compile(
|
||||
'class="item.*?<img.*?src="(.*?)".*?class="cover">', re.S)
|
||||
name_pattern = re.compile('<h2.*?>(.*?)</h2>')
|
||||
categories_pattern = re.compile(
|
||||
'<button.*?category.*?<span>(.*?)</span>.*?</button>', re.S)
|
||||
published_at_pattern = re.compile('(\d{4}-\d{2}-\d{2})\s?上映')
|
||||
drama_pattern = re.compile('<div.*?drama.*?>.*?<p.*?>(.*?)</p>', re.S)
|
||||
score_pattern = re.compile('<p.*?score.*?>(.*?)</p>', re.S)
|
||||
|
||||
cover = re.search(cover_pattern, html).group(
|
||||
1).strip() if re.search(cover_pattern, html) else None
|
||||
name = re.search(name_pattern, html).group(
|
||||
1).strip() if re.search(name_pattern, html) else None
|
||||
categories = re.findall(categories_pattern, html) if re.findall(
|
||||
categories_pattern, html) else []
|
||||
published_at = re.search(published_at_pattern, html).group(
|
||||
1) if re.search(published_at_pattern, html) else None
|
||||
drama = re.search(drama_pattern, html).group(
|
||||
1).strip() if re.search(drama_pattern, html) else None
|
||||
score = float(re.search(score_pattern, html).group(1).strip()
|
||||
) if re.search(score_pattern, html) else None
|
||||
return {
|
||||
'cover': cover,
|
||||
'name': name,
|
||||
'categories': categories,
|
||||
'published_at': published_at,
|
||||
'drama': drama,
|
||||
'score': score
|
||||
}
|
||||
|
||||
|
||||
def parse_detailByPyQuery(html):
|
||||
doc = pq(html)
|
||||
cover = doc('img.cover').attr('src')
|
||||
name = doc('a > h2').text()
|
||||
categories = [item.text() for item in doc('.categories button span').items()]
|
||||
published_at = doc('.info:contains(上映)').text()
|
||||
published_at = re.search('(\d{4}-\d{2}-\d{2})', published_at).group(1) \
|
||||
if published_at and re.search('\d{4}-\d{2}-\d{2}', published_at) else None
|
||||
drama = doc('.drama p').text()
|
||||
score = doc('p.score').text()
|
||||
score = float(score) if score else None
|
||||
return {
|
||||
'cover': cover,
|
||||
'name': name,
|
||||
'categories': categories,
|
||||
'published_at': published_at,
|
||||
'drama': drama,
|
||||
'score': score
|
||||
}
|
||||
|
||||
|
||||
def save_data(data):
|
||||
"""
|
||||
save to json file
|
||||
:param data:
|
||||
:return:
|
||||
"""
|
||||
name = data.get('name')
|
||||
data_path = f'{RESULTS_DIR}/{name}.json'
|
||||
# ensure_ascii值为False,可以保证中文字符在文件内能以正常的正文文本呈现,而不是unicode
|
||||
# indent为2,可以使json可以两行缩进
|
||||
json.dump(data, open(data_path, 'w', encoding='utf-8'),
|
||||
ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
for page in range(1, TOTAL_PAGE):
|
||||
index_html = scrape_index(page)
|
||||
detail_urls = parse_index(index_html)
|
||||
for detail_url in detail_urls:
|
||||
data = parse_detail(scrape_detail(detail_url))
|
||||
logging.info("get detail data %s", data)
|
||||
save_data(data)
|
||||
logging.info("save data successfully")
|
||||
|
||||
|
||||
def mainByMulti(page):
|
||||
index_html = scrape_index(page)
|
||||
detail_urls = parse_index(index_html)
|
||||
for detail_url in detail_urls:
|
||||
data = parse_detail(scrape_detail(detail_url))
|
||||
logging.info("get detail data %s", data)
|
||||
save_data(data)
|
||||
logging.info("save data successfully")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 单进程
|
||||
# main()
|
||||
# 多进程
|
||||
import multiprocessing
|
||||
|
||||
pool = multiprocessing.Pool()
|
||||
pages = range(1, TOTAL_PAGE)
|
||||
pool.map(mainByMulti, pages)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 18:59
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 19:00
|
||||
@Usage : httpx库学习;request不支持http2.0相关协议。需要使用httpx。
|
||||
@Desc : httpx用法与request类似,具体内容参考:https://github.com/Python3WebSpider/HttpxTest
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 15:51
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 15:51
|
||||
@Usage :
|
||||
@Desc : request库学习
|
||||
'''
|
||||
|
||||
import requests
|
||||
import re
|
||||
from requests.packages import urllib3
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests_oauthlib import OAuth1
|
||||
|
||||
'''
|
||||
基本get使用
|
||||
'''
|
||||
|
||||
|
||||
def get():
|
||||
data = {
|
||||
'name': 'germey',
|
||||
'age': 25
|
||||
}
|
||||
response = requests.get('https://httpbin.org/get', params=data)
|
||||
|
||||
print(response.text)
|
||||
print(type(response.json())) # dict
|
||||
print(response.json())
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
抓取网页:使用模式匹配,抓取标题
|
||||
'''
|
||||
|
||||
|
||||
def getPattern():
|
||||
response = requests.get('https://ssr1.scrape.center/')
|
||||
pattern = '<h2.*?>(.*?)</h2>'
|
||||
pattern = re.compile(pattern, re.S)
|
||||
titles = re.findall(pattern, response.text)
|
||||
|
||||
print(titles)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
抓取二进制数据:使用模式匹配,抓取标题
|
||||
'''
|
||||
|
||||
|
||||
def getBinary():
|
||||
response = requests.get('https://scrape.center/favicon.ico')
|
||||
print(response.text)
|
||||
print(response.content)
|
||||
|
||||
with open('favicon.ico', 'wb') as f:
|
||||
f.write(response.content)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
基本response的相关参数
|
||||
'''
|
||||
|
||||
|
||||
def getResponse():
|
||||
response = requests.get('https://ssr1.scrape.center/')
|
||||
|
||||
print(type(response.status_code), response.status_code)
|
||||
print(type(response.headers), response.headers)
|
||||
print(type(response.cookies), response.cookies)
|
||||
print(type(response.history), response.history)
|
||||
|
||||
exit() if not response.status_code == requests.codes.ok else print('Request Success!')
|
||||
|
||||
|
||||
'''
|
||||
基本post使用
|
||||
'''
|
||||
|
||||
|
||||
def post():
|
||||
data = {
|
||||
'name': 'germey',
|
||||
'age': 25
|
||||
}
|
||||
response = requests.post('https://httpbin.org/get', data=data)
|
||||
|
||||
print(response.text)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
高级用法:上传文件
|
||||
'''
|
||||
|
||||
|
||||
def postFile():
|
||||
file = {
|
||||
'file': open('favicon.ico', 'rb')
|
||||
}
|
||||
response = requests.post('https://httpbin.org/post', files=file)
|
||||
|
||||
print(response.text)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
高级用法:cookie
|
||||
cookie成功模拟了登录状态,这样就能爬取登录之后才能看到的页面了
|
||||
'''
|
||||
|
||||
|
||||
def postCookie():
|
||||
response = requests.get('https://www.baidu.com')
|
||||
print(response.cookies)
|
||||
|
||||
for key, value in response.cookies.items():
|
||||
print(key, "=", value)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
'''
|
||||
Session维持:
|
||||
如果第一次请求利用request库的post方法登录了某个网站,第二次想获取成功登录后自己的个人信息,
|
||||
于是又用requests库的get方法区请求个人信息页面,这实际上相当于打开了两个浏览器,是两个完全独立的操作,这时需要维持Session
|
||||
'''
|
||||
|
||||
|
||||
def session():
|
||||
s = requests.Session()
|
||||
s.get("https://www.httpbin.org/cookies/set/number/123456789")
|
||||
r = s.get('https://www.httpbin.org/cookies')
|
||||
print(r.text) # {"cookies": {"number": "123456789"}}
|
||||
|
||||
|
||||
'''
|
||||
SSL证书验证:
|
||||
有些网站的HTTPS证书可能并不被CA机构认可,出现SSL证书错误
|
||||
'''
|
||||
|
||||
|
||||
def SSL():
|
||||
# response = requests.get("https://ssr2.scrape.center/")
|
||||
# print(response.status_code) # requests.exceptions.SSLError: HTTPSConnectionPool(host='ssr2.scrape.center', port=443): Max retries exceeded with url
|
||||
|
||||
urllib3.disable_warnings()
|
||||
response = requests.get("https://ssr2.scrape.center/", verify=False)
|
||||
print(response.status_code) # 200
|
||||
|
||||
|
||||
'''
|
||||
超时验证:
|
||||
防止服务器不能即时响应
|
||||
'''
|
||||
|
||||
|
||||
def timeout():
|
||||
# 设置超时时间1秒
|
||||
response = requests.get("https://www.httpbin.org/get", timeout=1)
|
||||
# 如果不设置,则永久等待,如果设置为timeout=(5,30)则连接超时时间5秒,读取超时时间30秒
|
||||
print(response.status_code)
|
||||
|
||||
|
||||
'''
|
||||
身份认证:
|
||||
在访问启用了基本身份认证的网站时,首先会弹出一个认证窗口
|
||||
'''
|
||||
|
||||
|
||||
def Auth():
|
||||
# 用户名和密码都是admin
|
||||
response = requests.get("https://ssr3.scrape.center/", auth=('admin', 'admin'))
|
||||
|
||||
print(response.status_code)
|
||||
|
||||
|
||||
'''
|
||||
request还提供了其他认证方式,如OAuth认证,不过此时需要安装requests_oauthlib包
|
||||
'''
|
||||
def OAuth():
|
||||
# 用户名和密码都是admin
|
||||
url = 'https://api.twitter.com/1.1/account/verify_credentials.json'
|
||||
auth = OAuth('your_app_key', 'your_app_sercet', 'user_oauth_token', 'user_oauth_token_secret')
|
||||
response = requests.get(url, auth=auth)
|
||||
|
||||
print(response.status_code)
|
||||
|
||||
|
||||
|
||||
'''
|
||||
代理设置:
|
||||
某些网站请求几次可以正常获取内容。但一旦开始大规模爬取,可能弹出验证码或者挑战到登录认证页面等
|
||||
可以使用代理来解决这个问题
|
||||
'''
|
||||
def proxy():
|
||||
# 用户名和密码都是admin
|
||||
url = 'https://api.twitter.com/1.1/account/verify_credentials.json'
|
||||
|
||||
proxies ={
|
||||
'http':'http://10.10.10.10:1080',
|
||||
'https':'http://user:password@10.10.10.10:1080/'
|
||||
}
|
||||
|
||||
response = requests.get(url, proxy=proxies)
|
||||
|
||||
print(response.status_code)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Auth()
|
||||
|
|
@ -0,0 +1,953 @@
|
|||
<html lang="en">
|
||||
<head>
|
||||
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<link rel="icon" href="/static/img/favicon.ico">
|
||||
<title>Scrape | Movie</title>
|
||||
|
||||
|
||||
<link href="/static/css/app.css" type="text/css" rel="stylesheet">
|
||||
|
||||
<link href="/static/css/index.css" type="text/css" rel="stylesheet">
|
||||
|
||||
</head>
|
||||
<body>
|
||||
<div id="app">
|
||||
<div data-v-74e8b908="" class="el-row" id="header">
|
||||
<div data-v-74e8b908="" class="container el-col el-col-18 el-col-offset-3">
|
||||
<div data-v-74e8b908="" class="el-row">
|
||||
<div data-v-74e8b908="" class="logo el-col el-col-4">
|
||||
<a data-v-74e8b908="" href="/" class="router-link-exact-active router-link-active">
|
||||
<img data-v-74e8b908="" src="/static/img/logo.png" class="logo-image">
|
||||
<span data-v-74e8b908="" class="logo-title">Scrape</span>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" id="index">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-18 el-col-offset-3">
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/1"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/ce4da3e03e655b5b88ed31b5cd7896cf62472.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/1" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">霸王别姬 - Farewell My Concubine</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">中国内地、中国香港</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">171 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1993-07-26 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/2"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p1.meituan.net/movie/6bea9af4524dfbd0b668eaa7e187c3df767253.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/2" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">这个杀手不太冷 - Léon</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>动作</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>犯罪</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">法国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">110 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1994-09-14 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/3"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/283292171619cdfd5b240c8fd093f1eb255670.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/3" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">肖申克的救赎 - The Shawshank Redemption</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>犯罪</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">142 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1994-09-10 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/4"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p1.meituan.net/movie/b607fba7513e7f15eab170aac1e1400d878112.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/4" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">泰坦尼克号 - Titanic</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>灾难</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">194 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1998-04-03 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/5"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/289f98ceaa8a0ae737d3dc01cd05ab052213631.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/5" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">罗马假日 - Roman Holiday</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>喜剧</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">118 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1953-08-20 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/6"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/da64660f82b98cdc1b8a3804e69609e041108.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/6" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">唐伯虎点秋香 - Flirting Scholar</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>喜剧</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>古装</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">中国香港</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">102 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1993-07-01 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/7"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/223c3e186db3ab4ea3bb14508c709400427933.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/7" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">乱世佳人 - Gone with the Wind</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>历史</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>战争</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">238 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1939-12-15 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/8"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/1f0d671f6a37f9d7b015e4682b8b113e174332.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/8" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">喜剧之王 - The King of Comedy</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>喜剧</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>爱情</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">中国香港</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">85 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1999-02-13 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.5</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 75.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/9"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/8959888ee0c399b0fe53a714bc8a5a17460048.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/9" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">楚门的世界 - The Truman Show</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>剧情</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>科幻</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">103 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.0</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 50.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div data-v-7f856186="" class="el-card item m-t is-hover-shadow">
|
||||
<div class="el-card__body">
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-8 el-col-sm-6 el-col-md-4">
|
||||
<a data-v-7f856186=""
|
||||
href="/detail/10"
|
||||
class="">
|
||||
<img
|
||||
data-v-7f856186=""
|
||||
src="https://p0.meituan.net/movie/27b76fe6cf3903f3d74963f70786001e1438406.jpg@464w_644h_1e_1c"
|
||||
class="cover">
|
||||
</a>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="p-h el-col el-col-24 el-col-xs-9 el-col-sm-13 el-col-md-16">
|
||||
<a data-v-7f856186="" href="/detail/10" class="name">
|
||||
<h2 data-v-7f856186="" class="m-b-sm">狮子王 - The Lion King</h2>
|
||||
</a>
|
||||
<div data-v-7f856186="" class="categories">
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>动画</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>歌舞</span>
|
||||
</button>
|
||||
|
||||
<button data-v-7f856186="" type="button"
|
||||
class="el-button category el-button--primary el-button--mini">
|
||||
<span>冒险</span>
|
||||
</button>
|
||||
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
<span data-v-7f856186="">美国</span>
|
||||
<span data-v-7f856186=""> / </span>
|
||||
<span data-v-7f856186="">89 分钟</span>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="m-v-sm info">
|
||||
|
||||
<span data-v-7f856186="">1995-07-15 上映</span>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-col el-col-24 el-col-xs-5 el-col-sm-5 el-col-md-4">
|
||||
<p data-v-7f856186=""
|
||||
class="score m-t-md m-b-n-sm">
|
||||
9.0</p>
|
||||
<p data-v-7f856186="">
|
||||
<div data-v-7f856186="" role="slider" aria-valuenow="4.75" aria-valuetext=""
|
||||
aria-valuemin="0"
|
||||
aria-valuemax="5" tabindex="0" class="el-rate">
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on" style="color: rgb(247, 186, 42);"></i>
|
||||
</span>
|
||||
|
||||
<span
|
||||
class="el-rate__item" style="cursor: auto;"><i
|
||||
class="el-rate__icon el-icon-star-on"
|
||||
style="color: rgb(239, 242, 247);"><i
|
||||
class="el-rate__decimal el-icon-star-on"
|
||||
style="color: rgb(247, 186, 42); width: 50.0%;"></i></i>
|
||||
</span>
|
||||
</div>
|
||||
</p></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class="el-loading-mask" style="display: none;">
|
||||
<div class="el-loading-spinner">
|
||||
<svg viewBox="25 25 50 50" class="circular">
|
||||
<circle cx="50" cy="50" r="20" fill="none" class="path"></circle>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div data-v-7f856186="" class="el-row">
|
||||
<div data-v-7f856186="" class="el-col el-col-10 el-col-offset-11">
|
||||
<div data-v-7f856186="" class="pagination m-v-lg">
|
||||
<div data-v-7f856186="" class="el-pagination is-background">
|
||||
<span class="el-pagination__total">共 102 条</span>
|
||||
|
||||
<button type="button" disabled="disabled" class="btn-prev">
|
||||
<i class="el-icon el-icon-arrow-left"></i>
|
||||
</button>
|
||||
|
||||
<ul class="el-pager">
|
||||
|
||||
|
||||
<li class="number active">
|
||||
<a href="/page/1">1</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/2">2</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/3">3</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/4">4</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/5">5</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/6">6</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/7">7</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/8">8</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/9">9</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/10">10</a>
|
||||
</li>
|
||||
|
||||
|
||||
<li class="number">
|
||||
<a href="/page/11">11</a>
|
||||
</li>
|
||||
|
||||
|
||||
</ul>
|
||||
|
||||
<a href="/page/2" class="next">
|
||||
<button type="button" class="btn-next"><i class="el-icon el-icon-arrow-right"></i></button>
|
||||
</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</body>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 17:30
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/7 17:31
|
||||
@Usage :
|
||||
@Desc : 正则匹配re库基本使用
|
||||
'''
|
||||
|
||||
import re
|
||||
|
||||
'''
|
||||
基本使用
|
||||
'''
|
||||
|
||||
|
||||
def baseUse():
|
||||
content = 'hello 123456789 World_This is a Reges Demo'
|
||||
pattern = '^hello\s(\d+)\sWorld'
|
||||
result = re.match(pattern, content)
|
||||
|
||||
print(result) # <re.Match object; span=(0, 21), match='hello 123456789 World'>
|
||||
print(result.group()) # hello 123456789 World 输出匹配到的内容
|
||||
print(result.group(1)) # 123456789 输出第一个被()包围的匹配结果
|
||||
print(result.span()) # (0, 21) 输出匹配的范围
|
||||
|
||||
|
||||
# 高级用法
|
||||
'''
|
||||
贪婪匹配与非贪婪匹配
|
||||
.*表示尽可能多匹配字符
|
||||
.*?表示尽可能少匹配字符
|
||||
'''
|
||||
|
||||
'''
|
||||
re.I 表示匹配对大小写不明感
|
||||
re.L 实现本地化识别匹配
|
||||
re.M 表示多航匹配,影响^和$
|
||||
re.S 表示匹配内容包括换行符在内的所有字符
|
||||
re.U 表示根据Unicode解析字符,这个表示会影响\w,\W,\b和\B
|
||||
re.S 表示匹配内容包括换行符在内的所有字符
|
||||
'''
|
||||
def prior():
|
||||
content = '''hello 123456789 World_This
|
||||
is a Reges Demo
|
||||
'''
|
||||
result = re.match('^he.*?(\d+).*?Demo$', content)
|
||||
print(result.group(1)) # 未匹配到,报错 AttributeError: 'NoneType' object has no attribute 'group'
|
||||
result = re.match('^he.*?(\d+).*?Demo$', content, re.S)
|
||||
print(result.group(1)) # 123456789
|
||||
|
||||
|
||||
'''
|
||||
search:模糊匹配
|
||||
'''
|
||||
def search():
|
||||
content = 'Extra stings Hello 1234567 World_This is a Regex Demo Extra stings'
|
||||
result = re.match('Hello.*?(\d+).*?Demo', content) # 必须要以Hello开头才能匹配到
|
||||
print(result) # None
|
||||
result = re.search('Hello.*?(\d+).*?Demo', content)
|
||||
print(result) # <re.Match object; span=(13, 53), match='Hello 1234567 World_This is a Regex Demo'>
|
||||
|
||||
|
||||
def searchHtml():
|
||||
html = '''<div id="songs-list">
|
||||
<h2 class="title">经典老歌</h2>
|
||||
<p class="introduction">
|
||||
经典老歌列表
|
||||
</p>
|
||||
<ul id="list" class="list-group">
|
||||
<li data-view="2">一路上有你</li>
|
||||
<li data-view="7">
|
||||
<a href="/2.mp3" singer="任贤齐">沧海一声笑</a>
|
||||
</li>
|
||||
<li data-view="4" class="active">
|
||||
<a href="/3.mp3" singer="齐秦">往事随风</a>
|
||||
</li>
|
||||
<li data-view="6"><a href="/4.mp3" singer="beyond">光辉岁月</a></li>
|
||||
<li data-view="5"><a href="/5.mp3" singer="陈慧琳">记事本</a></li>
|
||||
<li data-view="5">
|
||||
<a href="/6.mp3" singer="邓丽君">但愿人长久</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>'''
|
||||
result = re.search('<li.*?active.*?singer="(.*?)">(.*?)</a>', html, re.S)
|
||||
if result:
|
||||
print(result.group(1), result.group(2))
|
||||
|
||||
result = re.search('<li.*?singer="(.*?)">(.*?)</a>', html, re.S)
|
||||
if result:
|
||||
print(result.group(1), result.group(2))
|
||||
|
||||
result = re.search('<li.*?singer="(.*?)">(.*?)</a>', html)
|
||||
if result:
|
||||
print(result.group(1), result.group(2))
|
||||
|
||||
'''
|
||||
findAll:找到所有匹配的
|
||||
'''
|
||||
def findall():
|
||||
html = '''<div id="songs-list">
|
||||
<h2 class="title">经典老歌</h2>
|
||||
<p class="introduction">
|
||||
经典老歌列表
|
||||
</p>
|
||||
<ul id="list" class="list-group">
|
||||
<li data-view="2">一路上有你</li>
|
||||
<li data-view="7">
|
||||
<a href="/2.mp3" singer="任贤齐">沧海一声笑</a>
|
||||
</li>
|
||||
<li data-view="4" class="active">
|
||||
<a href="/3.mp3" singer="齐秦">往事随风</a>
|
||||
</li>
|
||||
<li data-view="6"><a href="/4.mp3" singer="beyond">光辉岁月</a></li>
|
||||
<li data-view="5"><a href="/5.mp3" singer="陈慧琳">记事本</a></li>
|
||||
<li data-view="5">
|
||||
<a href="/6.mp3" singer="邓丽君">但愿人长久</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>'''
|
||||
results = re.findall('<li.*?href="(.*?)".*?singer="(.*?)">(.*?)</a>', html, re.S)
|
||||
print(results)
|
||||
print(type(results))
|
||||
for result in results:
|
||||
print(result)
|
||||
print(result[0], result[1], result[2])
|
||||
|
||||
results = re.findall('<li.*?>\s*?(<a.*?>)?(\w+)(</a>)?\s*?</li>', html, re.S)
|
||||
for result in results:
|
||||
print(result[1])
|
||||
|
||||
|
||||
'''
|
||||
sub:正则匹配修改文本
|
||||
使用正则匹配,去除掉能匹配上的内容
|
||||
'''
|
||||
def sub():
|
||||
html = '''<div id="songs-list">
|
||||
<h2 class="title">经典老歌</h2>
|
||||
<p class="introduction">
|
||||
经典老歌列表
|
||||
</p>
|
||||
<ul id="list" class="list-group">
|
||||
<li data-view="2">一路上有你</li>
|
||||
<li data-view="7">
|
||||
<a href="/2.mp3" singer="任贤齐">沧海一声笑</a>
|
||||
</li>
|
||||
<li data-view="4" class="active">
|
||||
<a href="/3.mp3" singer="齐秦">往事随风</a>
|
||||
</li>
|
||||
<li data-view="6"><a href="/4.mp3" singer="beyond">光辉岁月</a></li>
|
||||
<li data-view="5"><a href="/5.mp3" singer="陈慧琳">记事本</a></li>
|
||||
<li data-view="5">
|
||||
<a href="/6.mp3" singer="邓丽君">但愿人长久</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>'''
|
||||
html = re.sub('<a.*?>|</a>', '', html)
|
||||
print(html)
|
||||
results = re.findall('<li.*?>(.*?)</li>', html, re.S)
|
||||
for result in results:
|
||||
print(result.strip())
|
||||
|
||||
'''
|
||||
compile:编译
|
||||
将相关模式编译成pattern对象,进行复用
|
||||
'''
|
||||
def complie():
|
||||
content1 = '2019-12-15 12:00'
|
||||
content2 = '2019-12-17 12:55'
|
||||
content3 = '2019-12-22 13:21'
|
||||
pattern = re.compile('\d{2}:\d{2}')
|
||||
result1 = re.sub(pattern, '', content1)
|
||||
result2 = re.sub(pattern, '', content2)
|
||||
result3 = re.sub(pattern, '', content3)
|
||||
print(result1, result2, result3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sub()
|
||||
|
|
@ -27,7 +27,7 @@ batch_size = 32
|
|||
EPOCH = 1000
|
||||
unit = 512 # LSTM的维度
|
||||
predict_num = 50 # 预测个数
|
||||
model_name = "LSTM"
|
||||
model_name = "cnn_LSTM"
|
||||
save_name = r"self_{0}_hidden{1}_unit{2}_feature{3}_predict{4}.h5".format(model_name, hidden_num, unit, feature,
|
||||
predict_num)
|
||||
|
||||
|
|
@ -125,7 +125,8 @@ def predict_model(filter_num, dims):
|
|||
# LSTM = tf.keras.layers.LSTM(units=256, return_sequences=False)(LSTM)
|
||||
|
||||
#### 自己
|
||||
LSTM = LSTMLayer(units=512, return_sequences=True)(input)
|
||||
LSTM = tf.keras.layers.Conv1D(512, kernel_size=8, padding='same')(input)
|
||||
LSTM = LSTMLayer(units=512, return_sequences=True)(LSTM)
|
||||
LSTM = LSTMLayer(units=256, return_sequences=False)(LSTM)
|
||||
|
||||
x = tf.keras.layers.Dense(128, activation="relu")(LSTM)
|
||||
|
|
@ -199,8 +200,7 @@ if __name__ == '__main__':
|
|||
# model.summary()
|
||||
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
||||
#
|
||||
# history = model.fit(train_data, train_label_single, epochs=EPOCH,
|
||||
# batch_size=batch_size, validation_data=(val_data, val_label_single), shuffle=True, verbose=2,
|
||||
# history = model.fit(train_data, train_label_single, epochs=EPOCH, validation_data=(val_data, val_label_single), shuffle=True, verbose=1,
|
||||
# callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
#### TODO 测试
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/6 13:12
|
||||
@Usage :
|
||||
@Desc : 画loss
|
||||
'''
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.pyplot import rcParams
|
||||
|
||||
import lttb
|
||||
|
||||
file_name = "E:\论文写作\ISA论文返修\loss.txt"
|
||||
|
||||
|
||||
def aux_loss(file_name="E:\论文写作\ISA论文返修\loss.txt"):
|
||||
mse1_list = []
|
||||
mse2_list = []
|
||||
mse3_list = []
|
||||
total_list = []
|
||||
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
i = 0
|
||||
for ann in f.readlines():
|
||||
ann = ann.strip('\n') # 去除文本中的换行符
|
||||
if i % 4 == 0:
|
||||
# print(ann)
|
||||
mse1 = float(ann.split(" ")[1])
|
||||
mse1_list.append(mse1)
|
||||
elif i % 4 == 1:
|
||||
mse2 = float(ann.split(" ")[1])
|
||||
mse2_list.append(mse2)
|
||||
elif i % 4 == 2:
|
||||
mse3 = float(ann.split(" ")[1])
|
||||
mse3_list.append(mse3)
|
||||
elif i % 4 == 3:
|
||||
total = float(ann.split(" ")[5])
|
||||
total_list.append(total)
|
||||
i += 1
|
||||
|
||||
mse1_list = lttb.downsample(np.array([range(len(mse1_list)), mse1_list]).T, n_out=50)
|
||||
mse2_list = lttb.downsample(np.array([range(len(mse2_list)), mse2_list]).T, n_out=50)
|
||||
mse3_list = lttb.downsample(np.array([range(len(mse3_list)), mse3_list]).T, n_out=50)
|
||||
total_list = lttb.downsample(np.array([range(len(total_list)), total_list]).T, n_out=50)
|
||||
|
||||
plot_aux_loss(mse1_list, mse2_list, mse3_list)
|
||||
|
||||
|
||||
def plot_aux_loss(mse1_list, mse2_list, mse3_list):
|
||||
config = {
|
||||
"font.family": 'Times New Roman', # 设置字体类型
|
||||
"axes.unicode_minus": False, # 解决负号无法显示的问题
|
||||
"axes.labelsize": 13
|
||||
}
|
||||
rcParams.update(config)
|
||||
pic1 = plt.figure(figsize=(8, 3), dpi=200)
|
||||
# plt.ylim(0, 0.6) # 设置y轴范围
|
||||
plt.plot(mse1_list[:, 1], label='Predict Head 1 Loss')
|
||||
plt.plot(mse2_list[:, 1], label='Predict Head 2 Loss')
|
||||
plt.plot(mse3_list[:, 1], label='Predict Head 3 Loss')
|
||||
plt.title('Training loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('loss')
|
||||
plt.legend(loc='upper right')
|
||||
plt.grid(alpha=0.8)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def main_loss(file_name="E:\论文写作\ISA论文返修\main_loss.txt"):
|
||||
mse1_list = []
|
||||
cross_Entropy_list = []
|
||||
total_list = []
|
||||
|
||||
with open(file_name, 'r', encoding='utf-8') as f:
|
||||
i = 0
|
||||
for ann in f.readlines():
|
||||
ann = ann.strip('\n') # 去除文本中的换行符
|
||||
if i % 3 == 0:
|
||||
# print(ann)
|
||||
mse1 = float(ann.split(" ")[1])
|
||||
mse1_list.append(mse1)
|
||||
elif i % 3 == 1:
|
||||
cross_Entropy = float(ann.split(" ")[1])*0.02
|
||||
cross_Entropy_list.append(cross_Entropy)
|
||||
elif i % 3 == 2:
|
||||
total = float(ann.split(" ")[5])
|
||||
total_list.append(total)
|
||||
i += 1
|
||||
|
||||
mse1_list = lttb.downsample(np.array([range(len(mse1_list)), mse1_list]).T, n_out=40)
|
||||
cross_Entropy_list = lttb.downsample(np.array([range(len(cross_Entropy_list)), cross_Entropy_list]).T, n_out=40)
|
||||
total_list = lttb.downsample(np.array([range(len(total_list)), total_list]).T, n_out=40)
|
||||
|
||||
plot_main_loss(mse1_list, cross_Entropy_list,total_list)
|
||||
|
||||
# 0.014995 0.016387
|
||||
# 0.014384 0.015866
|
||||
# 0.014985 0.014261
|
||||
# 0.013008 0.01349
|
||||
# 0.013285 0.013139
|
||||
# 0.012999 0.012714
|
||||
# 0.011451 0.012477
|
||||
# 0.010055 0.012471
|
||||
# 0.010303 0.014595
|
||||
def plot_main_loss(mse1_list, mse2_list,total_loss):
|
||||
config = {
|
||||
"font.family": 'Times New Roman', # 设置字体类型
|
||||
"axes.unicode_minus": False, # 解决负号无法显示的问题
|
||||
"axes.labelsize": 13
|
||||
}
|
||||
rcParams.update(config)
|
||||
pic1 = plt.figure(figsize=(8, 3), dpi=200)
|
||||
# plt.ylim(0, 0.6) # 设置y轴范围
|
||||
plt.plot(mse1_list[:, 1], label='${L_{diff}}$')
|
||||
plt.plot(mse2_list[:, 1], label='${L_{correct}}$')
|
||||
plt.plot(total_loss[:, 1], label='${L_{main}}$')
|
||||
plt.title('Training loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('loss')
|
||||
plt.legend(loc='upper right')
|
||||
plt.grid(alpha=0.8)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main_loss()
|
||||
|
|
@ -56,12 +56,12 @@ def plot_result_banda(result_data):
|
|||
fig, ax = plt.subplots(1, 1)
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal','size':7} # 设置坐标标签的字体大小,字体
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axvline(result_data.shape[0] * 2 / 3 - 50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
|
|
@ -85,8 +85,8 @@ def plot_result_banda(result_data):
|
|||
# pad调整label与坐标轴之间的距离
|
||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||
# plt.yticks([index for index in indices1], classes1)
|
||||
plt.ylabel('Confidence',fontdict=font2)
|
||||
plt.xlabel('Time',fontdict=font2)
|
||||
plt.ylabel('Confidence', fontdict=font2)
|
||||
plt.xlabel('Time', fontdict=font2)
|
||||
plt.tight_layout()
|
||||
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
||||
plt.legend(loc='upper right', frameon=False, fontsize=4.5)
|
||||
|
|
@ -98,11 +98,11 @@ def plot_result_banda(result_data):
|
|||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||
bbox_transform=ax.transAxes)
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
axins.axvline(result_data.shape[0] * 2 / 3 - 50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||
# 设置放大区间
|
||||
# 设置放大区间
|
||||
zone_left = int(result_data.shape[0] * 2 / 3 -160)
|
||||
zone_left = int(result_data.shape[0] * 2 / 3 - 160)
|
||||
zone_right = int(result_data.shape[0] * 2 / 3) + 40
|
||||
# zone_left = int(result_data.shape[0] * 2 / 3 +250)
|
||||
# zone_right = int(result_data.shape[0] * 2 / 3) + 450
|
||||
|
|
@ -132,6 +132,7 @@ def plot_result_banda(result_data):
|
|||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def plot_result(result_data):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
|
|
@ -153,8 +154,7 @@ def plot_result(result_data):
|
|||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axvline(result_data.shape[0] * 2 / 3 - 15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
|
|
@ -188,11 +188,11 @@ def plot_result(result_data):
|
|||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||
bbox_transform=ax.transAxes)
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||
axins.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
axins.axvline(result_data.shape[0] * 2 / 3 - 15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||
# 设置放大区间
|
||||
# 设置放大区间
|
||||
zone_left = int(result_data.shape[0] * 2 / 3 -100)
|
||||
zone_left = int(result_data.shape[0] * 2 / 3 - 100)
|
||||
zone_right = int(result_data.shape[0] * 2 / 3) + 100
|
||||
x = list(range(result_data.shape[0]))
|
||||
|
||||
|
|
@ -221,6 +221,25 @@ def plot_result(result_data):
|
|||
pass
|
||||
|
||||
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
length, = data.shape
|
||||
t = 0
|
||||
data = pd.DataFrame(data).ewm(alpha=namuda).mean()
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return data, np.broadcast_to(mid, shape=[length, ]), np.broadcast_to(UCL, shape=[length, ]), np.broadcast_to(LCL,
|
||||
shape=[
|
||||
length, ])
|
||||
pass
|
||||
|
||||
|
||||
def plot_MSE(total_MSE, total_max):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
|
|
@ -251,12 +270,13 @@ def plot_MSE(total_MSE, total_max):
|
|||
classes = ['01/09/17', '08/09/17', '15/09/17', '22/09/17', '29/09/17']
|
||||
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.ylabel('Mse', fontsize=5)
|
||||
plt.ylabel('MSE', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.tight_layout()
|
||||
|
||||
plt.plot(total_max, "--", label="max", linewidth=0.5)
|
||||
plt.plot(total_MSE, label="mse", linewidth=0.5, color='purple')
|
||||
plt.plot(total_max, "--", label="Threshold", linewidth=0.5, color='red')
|
||||
plt.plot(total_max, "--", label="Threshold", linewidth=0.5, color='red')
|
||||
plt.plot(total_MSE, label="MSE", linewidth=0.5, color='purple')
|
||||
plt.legend(loc='best', frameon=False, fontsize=5)
|
||||
|
||||
# plt.plot(total_mean)
|
||||
|
|
@ -265,6 +285,54 @@ def plot_MSE(total_MSE, total_max):
|
|||
pass
|
||||
|
||||
|
||||
def plot_EWMA(model_name, total_MSE, total_mid, total_max, total_min):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (2.7, 2),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 5,
|
||||
'ytick.labelsize': 5,
|
||||
'legend.fontsize': 5,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
|
||||
result_data = total_MSE
|
||||
# 画出 y=1 这条水平线
|
||||
# plt.axhline(0.5, c='red', label='Failure threshold',lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.', lw=0.5, label="real fault")
|
||||
|
||||
indices = [result_data.shape[0] * i / 4 for i in range(5)]
|
||||
classes = ['09/01', '09/08', '09/15', '09/22', '09/29/']
|
||||
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.ylabel('MSE', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.tight_layout()
|
||||
|
||||
plt.plot(total_min, "--", label="LCL", linewidth=0.5, color='red')
|
||||
plt.plot(total_mid, "--", label="mean(X)", linewidth=0.5)
|
||||
plt.plot(total_max, "--", label="UCL", linewidth=0.5, color='red')
|
||||
plt.plot(total_MSE, label="MSE", linewidth=0.5, color='purple')
|
||||
plt.legend(loc='best', frameon=False, fontsize=5)
|
||||
|
||||
# plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
|
||||
plt.savefig('E:\论文写作\ISA论文返修\图片\比较方法的EWMD/{0}.png'.format(model_name))
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def plot_Corr(data, size: int = 1):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
|
|
@ -333,16 +401,21 @@ def plot_bar(y_data):
|
|||
|
||||
plt.bar(x_width[1], y_data[1], lw=1, color=['#F5E3C4'], width=0.5, label="GRU", edgecolor='black')
|
||||
plt.bar(x_width[2], y_data[2], lw=1, color=['#EBC99D'], width=0.5, label="CNN-GRU", edgecolor='black')
|
||||
plt.bar(x_width[3], y_data[3], lw=1, color=['#FFC79C'], width=0.5, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[4], lw=1, color=['#BEE9C7'], width=0.5, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[5], y_data[5], lw=1, color=['#B8E9D0'], width=0.5, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[6], lw=1, color=['#B9E9E2'], width=0.5, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[7], lw=1, color=['#D6E6F2'], width=0.5, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[8], lw=1, color=['#B4D1E9'], width=0.5, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[9], lw=1, color=['#AEB5EE'], width=0.5, label="RNet-D", edgecolor='black')
|
||||
plt.bar(x_width[10], y_data[10], lw=1, color=['#D2D3FC'], width=0.5, label="ResNet-18", edgecolor='black')
|
||||
plt.bar(x_width[11], y_data[11], lw=1, color=['#D5A9FF'], width=0.5, label="ResNet-C", edgecolor='black')
|
||||
plt.bar(x_width[12], y_data[12], lw=1, color=['#E000F5'], width=0.5, label="JMNet", edgecolor='black')
|
||||
|
||||
plt.bar(x_width[3], y_data[13], lw=1, color=['#FCD58B'], width=0.5, label="CG-QA", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[14], lw=1, color=['#FFBA75'], width=0.5, label="CG-3{0}".format(chr(963)), edgecolor='black')
|
||||
|
||||
plt.bar(x_width[5], y_data[3], lw=1, color=['#FFC79C'], width=0.5, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[4], lw=1, color=['#BEE9C7'], width=0.5, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[5], lw=1, color=['#B8E9D0'], width=0.5, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[6], lw=1, color=['#B9E9E2'], width=0.5, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[7], lw=1, color=['#D6E6F2'], width=0.5, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[10], y_data[8], lw=1, color=['#B4D1E9'], width=0.5, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[11], y_data[9], lw=1, color=['#AEB5EE'], width=0.5, label="RNet-D", edgecolor='black')
|
||||
plt.bar(x_width[12], y_data[10], lw=1, color=['#D2D3FC'], width=0.5, label="ResNet-18", edgecolor='black')
|
||||
plt.bar(x_width[13], y_data[11], lw=1, color=['#D5A9FF'], width=0.5, label="ResNet-C", edgecolor='black')
|
||||
plt.bar(x_width[14], y_data[12], lw=1, color=['#E000F5'], width=0.5, label="JRFPN", edgecolor='black')
|
||||
|
||||
|
||||
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
|
||||
plt.xticks([])
|
||||
|
|
@ -350,14 +423,14 @@ def plot_bar(y_data):
|
|||
plt.xlabel('Methods', fontsize=22)
|
||||
# plt.tight_layout()
|
||||
|
||||
num1, num2, num3, num4 = 0, 1, 3, 0
|
||||
num1, num2, num3, num4 = 0.015, 1, 3, 0
|
||||
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, handlelength=1,
|
||||
handletextpad=0.45, columnspacing=1)
|
||||
plt.ylim([-0.01, 5])
|
||||
plt.show()
|
||||
|
||||
|
||||
def acc(y_data=list):
|
||||
def acc(y_data: list):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (10, 6),
|
||||
|
|
@ -373,22 +446,31 @@ def acc(y_data=list):
|
|||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
|
||||
x_width = [i / 2 for i in range(0, len(y_data))]
|
||||
x_width = [i for i in range(0, len(y_data))]
|
||||
# x2_width = [i + 0.3 for i in x_width]
|
||||
|
||||
plt.bar(x_width[0], y_data[0], lw=1, color=['#FAF4E1'], width=0.25, label="CNN", edgecolor='black')
|
||||
plt.bar(x_width[1], y_data[1], lw=1, color=['#F5E3C4'], width=0.25, label="GRU", edgecolor='black')
|
||||
plt.bar(x_width[2], y_data[2], lw=1, color=['#EBC99D'], width=0.25, label="CNN-GRU", edgecolor='black')
|
||||
plt.bar(x_width[3], y_data[3], lw=1, color=['#FFC79C'], width=0.25, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[4], lw=1, color=['#BEE9C7'], width=0.25, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[5], y_data[5], lw=1, color=['#B8E9D0'], width=0.25, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[6], lw=1, color=['#B9E9E2'], width=0.25, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[7], lw=1, color=['#D6E6F2'], width=0.25, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[8], lw=1, color=['#B4D1E9'], width=0.25, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[9], lw=1, color=['#AEB5EE'], width=0.25, label="RNet-D", edgecolor='black')
|
||||
plt.bar(x_width[10], y_data[10], lw=1, color=['#D2D3FC'], width=0.25, label="ResNet-18", edgecolor='black')
|
||||
plt.bar(x_width[11], y_data[11], lw=1, color=['#D5A9FF'], width=0.25, label="ResNet-C", edgecolor='black')
|
||||
plt.bar(x_width[12], y_data[12], lw=1, color=['#E000F5'], width=0.25, label="JMNet", edgecolor='black')
|
||||
_width = [i for i in range(0, len(y_data))]
|
||||
# x2_width = [i + 0.3 for i in x_width]
|
||||
|
||||
plt.bar(x_width[0], y_data[0], lw=1, color=['#FAF4E1'], width=0.5, label="CNN", edgecolor='black')
|
||||
|
||||
plt.bar(x_width[1], y_data[1], lw=1, color=['#F5E3C4'], width=0.5, label="GRU", edgecolor='black')
|
||||
plt.bar(x_width[2], y_data[2], lw=1, color=['#EBC99D'], width=0.5, label="CNN-GRU", edgecolor='black')
|
||||
|
||||
plt.bar(x_width[3], y_data[13], lw=1, color=['#FCD58B'], width=0.5, label="CG-QA", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[14], lw=1, color=['#FFBA75'], width=0.5, label="CG-3{0}".format(chr(963)),
|
||||
edgecolor='black')
|
||||
|
||||
plt.bar(x_width[5], y_data[3], lw=1, color=['#FFC79C'], width=0.5, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[4], lw=1, color=['#BEE9C7'], width=0.5, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[5], lw=1, color=['#B8E9D0'], width=0.5, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[6], lw=1, color=['#B9E9E2'], width=0.5, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[7], lw=1, color=['#D6E6F2'], width=0.5, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[10], y_data[8], lw=1, color=['#B4D1E9'], width=0.5, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[11], y_data[9], lw=1, color=['#AEB5EE'], width=0.5, label="RNet-D", edgecolor='black')
|
||||
plt.bar(x_width[12], y_data[10], lw=1, color=['#D2D3FC'], width=0.5, label="ResNet-18", edgecolor='black')
|
||||
plt.bar(x_width[13], y_data[11], lw=1, color=['#D5A9FF'], width=0.5, label="ResNet-C", edgecolor='black')
|
||||
plt.bar(x_width[14], y_data[12], lw=1, color=['#E000F5'], width=0.5, label="JRFPN", edgecolor='black')
|
||||
|
||||
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
|
||||
plt.xticks([])
|
||||
|
|
@ -397,7 +479,8 @@ def acc(y_data=list):
|
|||
# plt.tight_layout()
|
||||
|
||||
num1, num2, num3, num4 = 0, 1, 3, 0
|
||||
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False,handlelength=1,handletextpad=0.45,columnspacing=1)
|
||||
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, handlelength=1,
|
||||
handletextpad=0.45, columnspacing=1)
|
||||
plt.ylim([60, 105])
|
||||
plt.show()
|
||||
|
||||
|
|
@ -424,13 +507,18 @@ def plot_FNR1(y_data):
|
|||
plt.bar(x_width[0], y_data[0], lw=1, color=['#FAF4E1'], width=0.5 * 5 / 6, label="CNN", edgecolor='black')
|
||||
plt.bar(x_width[1], y_data[1], lw=1, color=['#F5E3C4'], width=0.5 * 5 / 6, label="GRU", edgecolor='black')
|
||||
plt.bar(x_width[2], y_data[2], lw=1, color=['#EBC99D'], width=0.5 * 5 / 6, label="CNN-GRU", edgecolor='black')
|
||||
plt.bar(x_width[3], y_data[3], lw=1, color=['#FFC79C'], width=0.5 * 5 / 6, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[4], lw=1, color=['#BEE9C7'], width=0.5 * 5 / 6, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[5], y_data[5], lw=1, color=['#B8E9D0'], width=0.5 * 5 / 6, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[6], lw=1, color=['#B9E9E2'], width=0.5 * 5 / 6, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[7], lw=1, color=['#D6E6F2'], width=0.5 * 5 / 6, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[8], lw=1, color=['#B4D1E9'], width=0.5 * 5 / 6, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[9], lw=1, color=['#AEB5EE'], width=0.5 * 5 / 6, label="RNet-D", edgecolor='black')
|
||||
|
||||
plt.bar(x_width[3], y_data[10], lw=1, color=['#FCD58B'], width=0.5 * 5 / 6, label="CG-QA", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[11], lw=1, color=['#EBC99D'], width=0.5 * 5 / 6, label="CG-3{0}".format(chr(963)), edgecolor='black')
|
||||
|
||||
plt.bar(x_width[5], y_data[3], lw=1, color=['#FFC79C'], width=0.5 * 5 / 6, label="DCConv", edgecolor='black')
|
||||
plt.bar(x_width[6], y_data[4], lw=1, color=['#BEE9C7'], width=0.5 * 5 / 6, label="RepDCConv", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[5], lw=1, color=['#B8E9D0'], width=0.5 * 5 / 6, label="RNet-MSE", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[6], lw=1, color=['#B9E9E2'], width=0.5 * 5 / 6, label="RNet", edgecolor='black')
|
||||
plt.bar(x_width[9], y_data[7], lw=1, color=['#D6E6F2'], width=0.5 * 5 / 6, label="RNet-SE", edgecolor='black')
|
||||
plt.bar(x_width[10], y_data[8], lw=1, color=['#B4D1E9'], width=0.5 * 5 / 6, label="RNet-L", edgecolor='black')
|
||||
plt.bar(x_width[11], y_data[9], lw=1, color=['#AEB5EE'], width=0.5 * 5 / 6, label="RNet-D", edgecolor='black')
|
||||
|
||||
|
||||
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
|
||||
plt.xticks([])
|
||||
|
|
@ -440,7 +528,7 @@ def plot_FNR1(y_data):
|
|||
plt.xlabel('Methods', fontsize=22)
|
||||
# plt.tight_layout()
|
||||
|
||||
num1, num2, num3, num4 = 0.05, 1, 3, 0
|
||||
num1, num2, num3, num4 = 0.025, 1, 3, 0
|
||||
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, handlelength=1,
|
||||
handletextpad=0.45, columnspacing=1)
|
||||
|
||||
|
|
@ -482,12 +570,22 @@ def plot_FNR2(y_data):
|
|||
plt.bar(x_width[1], y_data[1], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[2], y_data[2], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[3], y_data[3], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[4], y_data[4],lw=1, color=['#D5A9FF'], width=0.5 * 2 / 3, label="RNet-C", edgecolor='black')
|
||||
plt.bar(x_width[4], y_data[4], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[5], y_data[5], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[6], y_data[6], color=['#FFFFFF'], label=" ")
|
||||
# plt.bar(x_width[7] + 2.0, y_data[10], lw=0.5, color=['#8085e9'], width=1, label="ResNet-18", edgecolor='black')
|
||||
|
||||
|
||||
plt.bar(x_width[6], y_data[6], lw=1, color=['#D5A9FF'], width=0.5 * 2 / 3, label="RNet-C", edgecolor='black')
|
||||
plt.bar(x_width[7], y_data[7], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[8], y_data[8],lw=1, color=['#E000F5'], width=0.5 * 2 / 3, label="JMNet", edgecolor='black')
|
||||
plt.bar(x_width[8], y_data[8], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[9], y_data[9], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[10], y_data[10], color=['#FFFFFF'], label=" ")
|
||||
plt.bar(x_width[11], y_data[11], color=['#FFFFFF'], label=" ")
|
||||
|
||||
# plt.bar(x_width[7] + 2.0, y_data[10], lw=0.5, color=['#8085e9'], width=1, label="ResNet-18", edgecolor='black')
|
||||
|
||||
|
||||
plt.bar(x_width[12], y_data[12], lw=1, color=['#E000F5'], width=0.5 * 2 / 3, label="JRFPN", edgecolor='black')
|
||||
plt.bar(x_width[13], y_data[13], color=['#FFFFFF'], label=" ")
|
||||
|
||||
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
|
||||
plt.xticks([])
|
||||
|
|
@ -554,14 +652,14 @@ def plot_hot_one(data):
|
|||
pass
|
||||
|
||||
|
||||
def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
|
||||
def plot_mse(file_name_mse="../others_idea/mse", data: str = ''):
|
||||
mse = np.loadtxt(file_name_mse, delimiter=",")
|
||||
raw_data=np.loadtxt(data,delimiter=",")
|
||||
raw_data=raw_data[:,:mse.shape[1]]
|
||||
print("mse:",mse.shape)
|
||||
print("raw_data:",raw_data.shape)
|
||||
res=raw_data-mse
|
||||
mse.shape[0]*2/3
|
||||
raw_data = np.loadtxt(data, delimiter=",")
|
||||
raw_data = raw_data[:, :mse.shape[1]]
|
||||
print("mse:", mse.shape)
|
||||
print("raw_data:", raw_data.shape)
|
||||
res = raw_data - mse
|
||||
mse.shape[0] * 2 / 3
|
||||
# mse = mse[2000:2300]
|
||||
# mse = mse[1800:2150]
|
||||
|
||||
|
|
@ -589,20 +687,19 @@ def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
|
|||
plt.axvline(res.shape[1] * 2 / 3, c='purple', ls='-.', lw=0.5, label="real fault")
|
||||
plt.plot(res[i, :], lw=0.5)
|
||||
|
||||
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_result_predict1.csv"):
|
||||
mse = np.loadtxt(file_name_mse, delimiter=",")
|
||||
|
||||
print("mse:",mse.shape)
|
||||
print("mse:", mse.shape)
|
||||
|
||||
need_shape=int(mse.shape[0]*2/3)
|
||||
need_shape = int(mse.shape[0] * 2 / 3)
|
||||
# mse = mse[2000:2300]
|
||||
# mse = mse[1800:2150]
|
||||
# mse = mse[ need_shape+100:need_shape+377]
|
||||
mse = mse[ need_shape-300:need_shape-10]
|
||||
mse = mse[need_shape - 300:need_shape - 10]
|
||||
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
|
|
@ -613,7 +710,7 @@ def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_
|
|||
'xtick.labelsize': 6,
|
||||
'ytick.labelsize': 6,
|
||||
'legend.fontsize': 5,
|
||||
'font.family':'Times New Roman'
|
||||
'font.family': 'Times New Roman'
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||
|
|
@ -622,7 +719,7 @@ def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_
|
|||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
indices = [mse.shape[0] * i / 4 for i in range(5)]
|
||||
# classes = ['07:21','08:21', '09:21', '10:21', '11:21']
|
||||
classes = ['01:58','02:58', '03:58', '04:58', '05:58']
|
||||
classes = ['01:58', '02:58', '03:58', '04:58', '05:58']
|
||||
|
||||
plt.xticks([index for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# pad调整label与坐标轴之间的距离
|
||||
|
|
@ -632,8 +729,6 @@ def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_
|
|||
plt.tight_layout()
|
||||
plt.plot(mse[:], lw=0.5)
|
||||
|
||||
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
|
|
@ -641,37 +736,36 @@ def plot_3d():
|
|||
# 线
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
data,label = read_data(file_name='G:\data\SCADA数据\jb4q_8_delete_total_zero.csv',isNew=False)
|
||||
data, label = read_data(file_name='G:\data\SCADA数据\jb4q_8_delete_total_zero.csv', isNew=False)
|
||||
print(data)
|
||||
print(data.shape)
|
||||
x=range(data.shape[1])
|
||||
y=range(data.shape[0])
|
||||
color=[[0.16,0.14,0.13],
|
||||
[0.89,0.09,0.05],
|
||||
[0.12,0.56,1.00],
|
||||
[0.01,0.66,0.62],
|
||||
[0.63,0.04,0.83],
|
||||
[0.63,0.32,0.18],
|
||||
[0.20,0.63,0.79],
|
||||
[0.50,0.16,0.16],
|
||||
[0.61,0.4,0.12],
|
||||
[1.00,0.38,0.00],
|
||||
[0.53,0.81,0.92],
|
||||
[0.13,0.55,0.13],
|
||||
[1.00,0.89,0.52],
|
||||
[0.44,0.50,0.41],
|
||||
[0.20,0.63,0.79],
|
||||
[0.00,0.78,0.55],
|
||||
[1.00,0.39,0.28],
|
||||
[0.25,0.41,0.88]]
|
||||
x = range(data.shape[1])
|
||||
y = range(data.shape[0])
|
||||
color = [[0.16, 0.14, 0.13],
|
||||
[0.89, 0.09, 0.05],
|
||||
[0.12, 0.56, 1.00],
|
||||
[0.01, 0.66, 0.62],
|
||||
[0.63, 0.04, 0.83],
|
||||
[0.63, 0.32, 0.18],
|
||||
[0.20, 0.63, 0.79],
|
||||
[0.50, 0.16, 0.16],
|
||||
[0.61, 0.4, 0.12],
|
||||
[1.00, 0.38, 0.00],
|
||||
[0.53, 0.81, 0.92],
|
||||
[0.13, 0.55, 0.13],
|
||||
[1.00, 0.89, 0.52],
|
||||
[0.44, 0.50, 0.41],
|
||||
[0.20, 0.63, 0.79],
|
||||
[0.00, 0.78, 0.55],
|
||||
[1.00, 0.39, 0.28],
|
||||
[0.25, 0.41, 0.88]]
|
||||
# c颜色,marker:样式*雪花
|
||||
|
||||
ax.plot( xs=y,ys=x, zs=data, c=color, marker="*")
|
||||
ax.plot(xs=y, ys=x, zs=data, c=color, marker="*")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
def test_result(file_name: str = result_file_name):
|
||||
# result_data = np.recfromcsv(file_name)
|
||||
result_data = np.loadtxt(file_name, delimiter=",")
|
||||
|
|
@ -680,7 +774,7 @@ def test_result(file_name: str = result_file_name):
|
|||
theshold = len(result_data)
|
||||
print(theshold)
|
||||
print(theshold * 2 / 3)
|
||||
theshold = theshold * 2 / 3-50
|
||||
theshold = theshold * 2 / 3 - 50
|
||||
# 计算误报率和漏报率
|
||||
positive_rate = result_data[:int(theshold)][result_data[:int(theshold)] < 0.66].__len__() / (
|
||||
theshold * 2 / 3)
|
||||
|
|
@ -783,22 +877,166 @@ def test_model_visualization(model_name=file_name):
|
|||
|
||||
plot_hot(needed_data)
|
||||
|
||||
def plot_loss():
|
||||
from matplotlib import rcParams
|
||||
|
||||
config = {
|
||||
"font.family": 'Times New Roman', # 设置字体类型
|
||||
"axes.unicode_minus": False, # 解决负号无法显示的问题
|
||||
"axes.labelsize": 13
|
||||
}
|
||||
rcParams.update(config)
|
||||
|
||||
pic1 = plt.figure(figsize=(8, 6), dpi=200)
|
||||
plt.subplot(211)
|
||||
plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc')
|
||||
plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc')
|
||||
plt.ylim(0.3, 1.0) # 设置y轴范围
|
||||
plt.title('Training & Validation accuracy')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('accuracy')
|
||||
plt.legend(loc='lower right')
|
||||
plt.grid(alpha=0.4)
|
||||
|
||||
plt.subplot(212)
|
||||
plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss')
|
||||
plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss')
|
||||
plt.ylim(0, 0.08) # 设置y轴范围
|
||||
plt.title('Training & Validation loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('loss')
|
||||
plt.legend(loc='upper right')
|
||||
plt.grid(alpha=0.4)
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = int(time.time())
|
||||
|
||||
# 将时间戳转换为字符串
|
||||
timestamp_str = str(timestamp)
|
||||
plt.savefig(timestamp_str, dpi=200)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model_list = ["DCConv", "GRU"]
|
||||
#
|
||||
# for model_name in model_list:
|
||||
# # model_name = "CNN_GRU"
|
||||
# mse = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}/banda\mse.csv'.format(
|
||||
# model_name),
|
||||
# delimiter=',')
|
||||
# max = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}/banda\max.csv'.format(
|
||||
# model_name),
|
||||
# delimiter=',')
|
||||
# data, mid, UCL, LCL = EWMA(mse)
|
||||
#
|
||||
# plot_EWMA(
|
||||
# model_name=model_name+"_banda",
|
||||
# total_MSE=data,
|
||||
# total_mid=mid,
|
||||
# total_max=UCL,
|
||||
# total_min=LCL
|
||||
# )
|
||||
#
|
||||
# model_list = ["RNet_D", "RNet_L","RNet_S","RNet_MSE"]
|
||||
# predict_list = ["predict1","predict2","predict3"]
|
||||
#
|
||||
# for model_name in model_list:
|
||||
# for predict_head in predict_list:
|
||||
# # model_name = "RNet_L"
|
||||
# # predict_head = "predict3"
|
||||
# mse = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}/banda\{0}_banda_mse_{1}.csv'.format(
|
||||
# model_name, predict_head),
|
||||
# delimiter=',')
|
||||
# max = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}/banda\{0}_banda_max_{1}.csv'.format(
|
||||
# model_name, predict_head),
|
||||
# delimiter=',')
|
||||
# data, mid, UCL, LCL = EWMA(mse)
|
||||
#
|
||||
# plot_EWMA(
|
||||
# model_name=model_name + predict_head+"_banda",
|
||||
# total_MSE=data,
|
||||
# total_mid=mid,
|
||||
# total_max=UCL,
|
||||
# total_min=LCL
|
||||
# )
|
||||
|
||||
# model_name = "RNet_D"
|
||||
# predict_head = "predict1"
|
||||
# mse = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}\{0}_timestamp120_feature10_mse_{1}.csv'.format(
|
||||
# model_name, predict_head),
|
||||
# delimiter=',')
|
||||
# max = np.loadtxt(
|
||||
# 'E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\{0}\{0}_timestamp120_feature10_max_{1}.csv'.format(
|
||||
# model_name, predict_head),
|
||||
# delimiter=',')
|
||||
# data, mid, UCL, LCL = EWMA(mse)
|
||||
#
|
||||
# plot_EWMA(
|
||||
# model_name=model_name + predict_head ,
|
||||
# total_MSE=data,
|
||||
# total_mid=mid,
|
||||
# total_max=UCL,
|
||||
# total_min=LCL
|
||||
# )
|
||||
|
||||
# test_result("E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_345\RNet_345_timestamp120_feature10_result.csv")
|
||||
# test_mse(fi)
|
||||
# test_result(
|
||||
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_timestamp120_feature10_result.csv')
|
||||
# test_corr()
|
||||
# acc()
|
||||
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
|
||||
# test_bar(list)
|
||||
|
||||
# list=[98.56,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
|
||||
# plot_FNR1(list)
|
||||
# #
|
||||
# list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
||||
list=[3.43,1.99,1.92,2.17,1.8,1.81,1.78,1.8,0.6]
|
||||
plot_FNR2(list)
|
||||
# list = [
|
||||
# 3.77,
|
||||
# 2.64,
|
||||
# 2.35,
|
||||
#
|
||||
# 2.05,
|
||||
# 1.76,
|
||||
# 1.09,
|
||||
# 0.76,
|
||||
# 0.82,
|
||||
# 1.10,
|
||||
# 0.58,
|
||||
# 0,
|
||||
# 0.03,
|
||||
# 0.02,
|
||||
# 0.21,
|
||||
# 2.18,
|
||||
# ]
|
||||
#
|
||||
# # test_bar(list)
|
||||
# list = [
|
||||
# 64.63,
|
||||
# 65.26,
|
||||
# 65.11,
|
||||
#
|
||||
# 66.6,
|
||||
# 67.15,
|
||||
# 73.86,
|
||||
# 66.28,
|
||||
# 75.24,
|
||||
# 73.98,
|
||||
# 76.7,
|
||||
# 98.86,
|
||||
# 99.38,
|
||||
# 99.79,
|
||||
# 70.41,
|
||||
# 66.71
|
||||
# ]
|
||||
# # acc(list)
|
||||
#
|
||||
list=[98.56,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74,88.36,95.52]
|
||||
plot_FNR1(list)
|
||||
# # # #
|
||||
# # list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
||||
# list=[3.43,0,0,0,0,0,1.8,0,0,0,0,0,0.6,0]
|
||||
# # list = [98.56, 98.95, 99.95, 96.1, 95, 99.65, 76.25, 72.64, 75.87, 68.74, 88.36, 95.52]
|
||||
# plot_FNR2(list)
|
||||
|
||||
# 查看网络某一层的权重
|
||||
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
|
||||
|
|
@ -810,11 +1048,11 @@ if __name__ == '__main__':
|
|||
# 单独预测图
|
||||
# plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
|
||||
# plot_mse_single('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D/banda\RNet_D_banda_mse_predict1.csv')
|
||||
#画3d图
|
||||
# plot_mse_single('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_3\RNet_3_timestamp120_feature10_result.csv')
|
||||
# 画3d图
|
||||
# plot_3d()
|
||||
|
||||
|
||||
#原始数据图
|
||||
# 原始数据图
|
||||
# file_names='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv'
|
||||
# data= np.loadtxt(file_names,delimiter=',')
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring_banda import Joint_Monitoring
|
||||
|
||||
|
|
@ -40,7 +39,7 @@ save_name = "../hard_model/weight/{0}_epoch16_0.0009_0.0014/weight".format(model
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_epoch24_9875_9867/weight".format(model_name,
|
||||
save_step_two_name = "../hard_model/two_weight/temp{0}/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -385,7 +384,7 @@ def train_step_one(train_data, train_label1, train_label2):
|
|||
k = k + 1
|
||||
val_loss, val_accuracy = model.get_val_loss(val_data=val_data, val_label1=val_label1, val_label2=val_label2,
|
||||
is_first_time=True)
|
||||
SaveBestModel(model=model, save_name=save_name, history_loss=history_val_loss, loss_value=val_loss.numpy())
|
||||
SaveBestModel(model=model, save_name=save_name, history_loss=history_val_loss, loss_value=val_loss.numpy(),is_all=True)
|
||||
# SaveBestH5Model(model=model, save_name=save_name, history_loss=history_val_loss, loss_value=val_loss.numpy())
|
||||
history_val_loss.append(val_loss)
|
||||
history_loss.append(loss_value.numpy())
|
||||
|
|
@ -586,28 +585,32 @@ if __name__ == '__main__':
|
|||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
# train_step_one(train_data=train_data_healthy[:128, :, :], train_label1=train_label1_healthy[:128, :],train_label2=train_label2_healthy[:128, ])
|
||||
# 整体训练
|
||||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
# step_one_model.load_weights(save_name)
|
||||
#
|
||||
# step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
|
||||
#### TODO 第二步训练
|
||||
### healthy_data.shape: (300333,120,10)
|
||||
### unhealthy_data.shape: (16594,10)
|
||||
|
||||
#### 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
step_one_model = Joint_Monitoring()
|
||||
step_one_model.load_weights(save_name)
|
||||
|
||||
step_two_model = Joint_Monitoring()
|
||||
step_two_model.load_weights(save_name)
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
# train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
|
||||
# healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
# unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
|
||||
healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
|
||||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
### TODO 测试测试集
|
||||
# step_one_model = Joint_Monitoring()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention
|
||||
from model.ChannelAttention.Dynamic_channelAttention import DynamicChannelAttention
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.compare.RNet_C import Joint_Monitoring
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention
|
||||
import time
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.compare.RNet import Joint_Monitoring
|
||||
|
||||
|
|
@ -591,10 +592,21 @@ if __name__ == '__main__':
|
|||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
step_one_model = Joint_Monitoring()
|
||||
step_one_model.load_weights(save_name)
|
||||
|
||||
# step_one_model.build(input_shape=(30,120,10))
|
||||
# step_one_model.summary()
|
||||
#
|
||||
# step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
step_one_model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
# #### TODO 计算MSE
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import seaborn as sns
|
|||
from sklearn.model_selection import train_test_split
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from keras.callbacks import EarlyStopping
|
||||
|
||||
import time
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
|
|
@ -68,22 +68,7 @@ unhealthy_patience = 5
|
|||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 10} # 设置坐标标签的字体大小,字体
|
||||
|
||||
|
||||
# train_data = np.load("../../../data/train_data.npy")
|
||||
# train_label = np.load("../../../data/train_label.npy")
|
||||
# test_data = np.load("../../../data/test_data.npy")
|
||||
# test_label = np.load("../../../data/test_label.npy")
|
||||
|
||||
|
||||
# CIFAR_100_data = tf.keras.datasets.cifar100
|
||||
# (train_data, train_label), (test_data, test_label) = CIFAR_100_data.load_data()
|
||||
# train_data=np.array(train_data)
|
||||
# train_label=np.array(train_label)
|
||||
# print(train_data.shape)
|
||||
# print(train_label.shape)
|
||||
# print(train_data)
|
||||
# print(test_data)
|
||||
#
|
||||
#
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
|
|
@ -328,6 +313,14 @@ if __name__ == '__main__':
|
|||
# model.save("./model/ResNet.h5")
|
||||
model = tf.keras.models.load_model("model/ResNet_banda/ResNet_banda_epoch10_9884.h5")
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
# 结果展示
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
|
|
|
|||
|
|
@ -0,0 +1,496 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,与JointNet相同深度,进行预测
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "./mse/DCConv/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/DCConv/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:415548行均是正常值(2019/7.30 00:00:00 - 2019/9/18 11:14:00)
|
||||
从415549:432153行均是异常值(2019/9/18 11:21:01 - 2021/1/18 00:00:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3)(input)
|
||||
LSTM = tf.keras.layers.Conv1D(20, 3)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(20, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(40, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(80, return_sequences=False)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=64)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=128)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(40, 3, padding="causal",dilation_rate=2)(LSTM)
|
||||
|
||||
# LSTM = LSTM[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1,9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
model = DCConv_Model()
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
|
||||
# history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
# train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
# batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
# callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
## TODO testing
|
||||
# # test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
# model = tf.keras.models.load_model(save_name)
|
||||
# # mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
#
|
||||
# start = time.time()
|
||||
# # 中间写上代码块
|
||||
#
|
||||
# model.predict(train_data_healthy, batch_size=32)
|
||||
# end = time.time()
|
||||
# print("data_size:", train_data_healthy.shape)
|
||||
# print('Running time: %s Seconds' % (end - start))
|
||||
#
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# 单次测试
|
||||
# getResult(newModel,
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# :],
|
||||
# healthy_label=train_label1_healthy[
|
||||
# healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy,isSave=False)
|
||||
# mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,497 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,与JointNet相同深度,进行预测
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse/CNN/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse/CNN/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\jb4q_8_delete_all_zero.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_all_zero.csv是删除了除异常以外的所有0值的文件
|
||||
文件从0:300454行均是正常值(2019/7.30 00:00:00 - 2019/9/18 11:21:00)
|
||||
从300455:317052行均是异常值(2019/9/18 11:21:01 - 2019/9/29 23:59:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 300454
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 317052
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3)(input)
|
||||
LSTM = tf.keras.layers.Conv1D(20, 3)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(20, 3)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(40, 3)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(80, 3)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=64)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=128)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(40, 3, padding="causal",dilation_rate=2)(LSTM)
|
||||
|
||||
LSTM = LSTM[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1,9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
model = DCConv_Model()
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
|
||||
history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
## TODO testing
|
||||
# test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
# newModel = tf.keras.models.load_model(save_name)
|
||||
# mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# 单次测试
|
||||
# getResult(newModel,
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# :],
|
||||
# healthy_label=train_label1_healthy[
|
||||
# healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy,isSave=True)
|
||||
# mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,496 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,与JointNet相同深度,进行预测
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "./mse/DCConv/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/DCConv/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:415548行均是正常值(2019/7.30 00:00:00 - 2019/9/18 11:14:00)
|
||||
从415549:432153行均是异常值(2019/9/18 11:21:01 - 2021/1/18 00:00:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=2)(input)
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=4)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=8)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=16)(LSTM)
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=32)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=64)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=128)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(40, 3, padding="causal",dilation_rate=2)(LSTM)
|
||||
|
||||
LSTM = LSTM[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1,9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
model = DCConv_Model()
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
|
||||
history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
## TODO testing
|
||||
# test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
# newModel = tf.keras.models.load_model(save_name)
|
||||
# mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# 单次测试
|
||||
# getResult(newModel,
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# :],
|
||||
# healthy_label=train_label1_healthy[
|
||||
# healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy,isSave=True)
|
||||
# mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,493 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,与JointNet相同深度,进行预测
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "./mse/DCConv/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/DCConv/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:415548行均是正常值(2019/7.30 00:00:00 - 2019/9/18 11:14:00)
|
||||
从415549:432153行均是异常值(2019/9/18 11:21:01 - 2021/1/18 00:00:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM = tf.keras.layers.GRU(10, return_sequences=True)(input)
|
||||
LSTM = tf.keras.layers.GRU(20, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(20, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(40, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(80, return_sequences=False)(LSTM)
|
||||
|
||||
# LSTM = LSTM[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1,9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
model = DCConv_Model()
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
|
||||
# history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
# train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
# batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
# callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
## TODO testing
|
||||
# test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
model = tf.keras.models.load_model(save_name)
|
||||
# mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# 单次测试
|
||||
# getResult(newModel,
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# :],
|
||||
# healthy_label=train_label1_healthy[
|
||||
# healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy,isSave=True)
|
||||
# mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,520 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,与JointNet相同深度,进行预测
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "./mse/DCConv/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/DCConv/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:415548行均是正常值(2019/7.30 00:00:00 - 2019/9/18 11:14:00)
|
||||
从415549:432153行均是异常值(2019/9/18 11:21:01 - 2021/1/18 00:00:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM1 = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=2)(input)
|
||||
LSTM2 = tf.keras.layers.Conv1D(10, 2, padding="causal", dilation_rate=2)(input)
|
||||
LSTM3 = tf.keras.layers.Conv1D(10, 1, padding="causal", dilation_rate=2)(input)
|
||||
|
||||
t1 = tf.add(tf.add(LSTM1, LSTM2), LSTM3)
|
||||
|
||||
LSTM1 = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=4)(t1)
|
||||
LSTM2 = tf.keras.layers.Conv1D(10, 2, padding="causal", dilation_rate=4)(t1)
|
||||
LSTM3 = tf.keras.layers.Conv1D(10, 1, padding="causal", dilation_rate=4)(t1)
|
||||
|
||||
t2 = tf.add(tf.add(LSTM1, LSTM2), LSTM3)
|
||||
|
||||
|
||||
|
||||
LSTM1 = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=8)(t2)
|
||||
LSTM2 = tf.keras.layers.Conv1D(10, 2, padding="causal", dilation_rate=8)(t2)
|
||||
LSTM3 = tf.keras.layers.Conv1D(10, 1, padding="causal", dilation_rate=8)(t2)
|
||||
|
||||
t3 = tf.add(tf.add(LSTM1, LSTM2), LSTM3)
|
||||
|
||||
|
||||
LSTM1 = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=16)(t3)
|
||||
LSTM2 = tf.keras.layers.Conv1D(10, 2, padding="causal", dilation_rate=16)(t3)
|
||||
LSTM3 = tf.keras.layers.Conv1D(10, 1, padding="causal", dilation_rate=16)(t3)
|
||||
|
||||
t4 = tf.add(tf.add(LSTM1, LSTM2), LSTM3)
|
||||
|
||||
|
||||
|
||||
LSTM1 = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=32)(t4)
|
||||
LSTM2 = tf.keras.layers.Conv1D(10, 2, padding="causal", dilation_rate=32)(t4)
|
||||
LSTM3 = tf.keras.layers.Conv1D(10, 1, padding="causal", dilation_rate=32)(t4)
|
||||
|
||||
t5 = tf.add(tf.add(LSTM1, LSTM2), LSTM3)
|
||||
|
||||
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=64)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=128)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(40, 3, padding="causal",dilation_rate=2)(LSTM)
|
||||
|
||||
LSTM = t5[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
# 单次测试
|
||||
model = DCConv_Model()
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
#
|
||||
# history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
# train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
# batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
# callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
## TODO testing
|
||||
# test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
newModel = tf.keras.models.load_model(save_name)
|
||||
# 单次测试
|
||||
# getResult(newModel,
|
||||
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# :],
|
||||
# healthy_label=train_label1_healthy[
|
||||
# healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy, isSave=True)
|
||||
# mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 10:43
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 10:43
|
||||
@Usage :
|
||||
@Desc : 计算transformer的参数和时间复杂度 6层self_attention
|
||||
'''
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import Model, layers, initializers
|
||||
import numpy as np
|
||||
from model.SelfAttention.SelfAttention import Block
|
||||
|
||||
|
||||
class Transformer(Model):
|
||||
|
||||
# depth表示的是重复encoder block的次数,num_heads表示的是在multi-head self-attention中head的个数
|
||||
# MLP block中有一个Pre_logist,这里指的是,当在较大的数据集上学习的时候Pre_logist就表示一个全连接层加上一个tanh激活函数
|
||||
# 当在较小的数据集上学习的时候,Pre_logist是没有的,而这里的representation_size表示的就是Pre_logist中全连接层的节点个数
|
||||
# num_classes表示分类的类数
|
||||
def __init__(self, embed_dim=768,
|
||||
depth=12, num_heads=12, qkv_bias=True, qk_scale=None,
|
||||
drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
|
||||
representation_size=None, num_classes=1000, name="ViT-B/16"):
|
||||
super(Transformer, self).__init__(name=name)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.depth = depth
|
||||
self.num_heads = num_heads
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qk_scale = qk_scale
|
||||
self.drop_ratio = drop_ratio
|
||||
self.attn_drop_ratio = attn_drop_ratio
|
||||
self.drop_path_ratio = drop_path_ratio
|
||||
self.representation_size = representation_size
|
||||
self.num_classes = num_classes
|
||||
|
||||
dpr = np.linspace(0., drop_path_ratio, depth) # stochastic depth decay rule
|
||||
# 用一个for循环重复Block模块
|
||||
# 在用droppath时的drop_path_ratio是由0慢慢递增到我们所指定的drop_path_ratio的
|
||||
# 所以我们在构建Block时,这里的drop_path_ratio时变化的,所以用 np.linspace方法创建一个等差数列来初始化drop_path_ratio
|
||||
self.blocks = [Block(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
|
||||
drop_path_ratio=dpr[i], name="encoderblock_{}".format(i))
|
||||
for i in range(depth)]
|
||||
|
||||
self.norm = layers.LayerNormalization(epsilon=1e-6, name="encoder_norm")
|
||||
|
||||
# 接下来,如果传入了representation_size,就构建一个全连接层,激活函数为tanh
|
||||
# 如果没有传入的话,就不做任何操作
|
||||
if representation_size:
|
||||
self.has_logits = True
|
||||
self.pre_logits = layers.Dense(representation_size, activation="tanh", name="pre_logits")
|
||||
else:
|
||||
self.has_logits = False
|
||||
self.pre_logits = layers.Activation("linear")
|
||||
|
||||
# 定义最后一个全连接层,节点个数就是我们的分类个数num_classes
|
||||
self.head = layers.Dense(num_classes, name="head", kernel_initializer=initializers.Zeros())
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'embed_dim': self.embed_dim,
|
||||
'depth': self.depth,
|
||||
'num_heads': self.num_heads,
|
||||
'qkv_bias': self.qkv_bias,
|
||||
'qk_scale': self.qk_scale,
|
||||
'drop_ratio': self.drop_ratio,
|
||||
'attn_drop_ratio': self.attn_drop_ratio,
|
||||
'drop_path_ratio': self.drop_path_ratio,
|
||||
'representation_size': self.representation_size,
|
||||
'num_classes': self.num_classes
|
||||
}
|
||||
)
|
||||
base_config = super(Transformer, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
# [B, H, W, C] -> [B, num_patches, embed_dim]
|
||||
x = inputs # [B, 196, 768]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, training=training)
|
||||
|
||||
x = self.norm(x)
|
||||
# 这里是提取class_toke的输出,然后用切片的方式,而刚刚是将class_toke拼接在最前面的
|
||||
# 所以这里用切片的方式,去取class_toke的输出,并将它传递给pre_logits
|
||||
x = self.pre_logits(x[:, 0])
|
||||
# 最后传递给head
|
||||
x = self.head(x)
|
||||
# 为什么只用class_toke对应的输出,而不用每一个patches对应的输出呢?
|
||||
# 可以参考原文bird 网络
|
||||
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 使用方式
|
||||
Transformer(embed_dim=10, depth=8, num_heads=1, num_classes=10)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 10:43
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,569 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/11 18:52
|
||||
@Usage : 对比实验,使用四分位图技术
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
from model.Joint_Monitoring.Joint_Monitoring3 import Joint_Monitoring
|
||||
|
||||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from keras.callbacks import EarlyStopping
|
||||
import random
|
||||
import time
|
||||
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "DCConv"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./trianed/{0}_{1}_{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../hard_model/two_weight/{0}_timestamp{1}_feature{2}_weight_epoch14/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_mse_name = "./mse/DCConv/banda/mse.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/DCConv/banda/max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
# save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
# save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
# time_stamp,
|
||||
# feature_num,
|
||||
# batch_size,
|
||||
# EPOCH)
|
||||
'''文件名'''
|
||||
'''文件名'''
|
||||
save_mse_name=r"./compare/mse/JM_banda/{0}_result.csv".format(model_name)
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:96748行均是正常值(2019/12.30 00:00:00 - 2020/3/11 05:58:00)
|
||||
从96748:107116行均是异常值(2020/3/11 05:58:01 - 2021/3/18 11:04:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
def shuffle(train_data, train_label1, train_label2, is_split: bool = False, split_size: float = 0.2):
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(train_data,
|
||||
train_label1,
|
||||
train_label2,
|
||||
test_size=split_size,
|
||||
shuffle=True,
|
||||
random_state=100)
|
||||
if is_split:
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
train_data = np.concatenate([train_data, test_data], axis=0)
|
||||
train_label1 = np.concatenate([train_label1, test_label1], axis=0)
|
||||
train_label2 = np.concatenate([train_label2, test_label2], axis=0)
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2
|
||||
pass
|
||||
|
||||
|
||||
def split_test_data(healthy_data, healthy_label1, healthy_label2, unhealthy_data, unhealthy_label1, unhealthy_label2,
|
||||
split_size: float = 0.2, shuffle: bool = True):
|
||||
data = np.concatenate([healthy_data, unhealthy_data], axis=0)
|
||||
label1 = np.concatenate([healthy_label1, unhealthy_label1], axis=0)
|
||||
label2 = np.concatenate([healthy_label2, unhealthy_label2], axis=0)
|
||||
(train_data, test_data, train_label1, test_label1, train_label2, test_label2) = train_test_split(data,
|
||||
label1,
|
||||
label2,
|
||||
test_size=split_size,
|
||||
shuffle=shuffle,
|
||||
random_state=100)
|
||||
|
||||
# print(train_data.shape)
|
||||
# print(train_label1.shape)
|
||||
# print(train_label2.shape)
|
||||
# print(train_data.shape)
|
||||
|
||||
return train_data, train_label1, train_label2, test_data, test_label1, test_label2
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
size, length, dims = test_data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = test_data[epoch:epoch + batch_size, :, :]
|
||||
_, _, _, output4 = step_two_model.call(each_test_data, is_first_time=False)
|
||||
total_result.append(output4)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
if isPlot:
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
# plt.text(35000, 0.9, "Truth Fault", fontsize=10, color='black', verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xlabel("time")
|
||||
plt.ylabel("confience")
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10})
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def DCConv_Model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
input = tf.cast(input, tf.float32)
|
||||
|
||||
LSTM = tf.keras.layers.Conv1D(10, 3)(input)
|
||||
LSTM = tf.keras.layers.Conv1D(20, 3)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(20, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(40, return_sequences=True)(LSTM)
|
||||
LSTM = tf.keras.layers.GRU(80, return_sequences=False)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=64)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(10, 3, padding="causal", dilation_rate=128)(LSTM)
|
||||
# LSTM = tf.keras.layers.Conv1D(40, 3, padding="causal",dilation_rate=2)(LSTM)
|
||||
|
||||
# LSTM = LSTM[:, -1, :]
|
||||
# bn = tf.keras.layers.BatchNormalization()(LSTM)
|
||||
|
||||
# d1 = tf.keras.layers.Dense(20)(LSTM)
|
||||
# bn = tf.keras.layers.BatchNormalization()(d1)
|
||||
|
||||
output = tf.keras.layers.Dense(128, name='output1')(LSTM)
|
||||
output = tf.keras.layers.Dense(10, name='output')(output)
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True, predictI: int = 1):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = True, predictI: int = 1):
|
||||
# TODO 计算MSE确定阈值
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
faultNum = mse[mse[:] > max[0]].__len__()
|
||||
# for i in range(total):
|
||||
# if (mse[i] > max[i]):
|
||||
# faultNum += 1
|
||||
# faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
|
||||
total_mse = np.concatenate([mse, mse1], axis=0)
|
||||
total_max = np.broadcast_to(max[0], shape=[total_mse.shape[0], ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
total_mean = np.broadcast_to(mean[0], shape=[total_mse.shape[0], ])
|
||||
if isSave:
|
||||
save_mse_name1 = save_mse_name
|
||||
save_max_name1 = save_max_name
|
||||
|
||||
np.savetxt(save_mse_name1, total_mse, delimiter=',')
|
||||
np.savetxt(save_max_name1, total_max, delimiter=',')
|
||||
|
||||
all, = mse1.shape
|
||||
|
||||
missNum = mse1[mse1[:] < max[0]].__len__()
|
||||
|
||||
print("all:", all)
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
|
||||
plt.figure(random.randint(1, 100))
|
||||
plt.plot(total_max)
|
||||
plt.plot(total_mse)
|
||||
plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def iqr_outliers(df: pd.DataFrame):
|
||||
q1 = df.quantile(0.25)
|
||||
q3 = df.quantile(0.75)
|
||||
iqr = q3 - q1
|
||||
Lower_tail = q1 - 1.5 * iqr
|
||||
Upper_tail = q3 + 1.5 * iqr
|
||||
outlier = []
|
||||
for i in df.iloc[:, 0]:
|
||||
if i > float(Upper_tail) or i < float(Lower_tail): # float限定
|
||||
outlier.append(i)
|
||||
print("Outliers:", outlier)
|
||||
|
||||
|
||||
def iqr_outliers_np_all(df):
|
||||
length, feature = df.shape
|
||||
toatl_Outliers = set()
|
||||
for i in range(feature):
|
||||
cur = df[:, i]
|
||||
cur = np.array(cur)
|
||||
q1 = np.percentile(cur, [25])
|
||||
q3 = np.percentile(cur, [75])
|
||||
iqr = q3 - q1
|
||||
Lower_tail = q1 - 1.5 * iqr
|
||||
Upper_tail = q3 + 1.5 * iqr
|
||||
cur_Outliers = []
|
||||
for i in range(len(cur)):
|
||||
if cur[i] > float(Upper_tail) or cur[i] < float(Lower_tail): # float限定
|
||||
toatl_Outliers.add(i)
|
||||
cur_Outliers.append(i)
|
||||
print("cur_Outliers.shape:", len(cur_Outliers))
|
||||
print("cur_Outliers:", cur_Outliers)
|
||||
print("Outliers.shape:", len(toatl_Outliers))
|
||||
print("Outliers:", toatl_Outliers)
|
||||
unhealthy_outlier = []
|
||||
for s in toatl_Outliers:
|
||||
if s >= healthy_date:
|
||||
unhealthy_outlier.append(s)
|
||||
print("unhealthy_outlier.shape:", len(unhealthy_outlier))
|
||||
print("unhealthy_outlier.shape:", len(unhealthy_outlier)/(unhealthy_date-healthy_date))
|
||||
print("unhealthy_outlier:", unhealthy_outlier)
|
||||
return sorted(unhealthy_outlier)
|
||||
|
||||
|
||||
def sigma_outliers_np_all(df):
|
||||
length, feature = df.shape
|
||||
toatl_Outliers = set()
|
||||
for i in range(feature):
|
||||
cur = df[:, i]
|
||||
cur = np.array(cur)
|
||||
mean = np.mean(cur[:200000,])
|
||||
sigma = np.sqrt(np.var(cur[:200000,]))
|
||||
|
||||
Lower_tail = mean - 3 * sigma
|
||||
Upper_tail = mean + 3 * sigma
|
||||
cur_Outliers = []
|
||||
for i in range(len(cur)):
|
||||
if cur[i] > float(Upper_tail) or cur[i] < float(Lower_tail): # float限定
|
||||
toatl_Outliers.add(i)
|
||||
cur_Outliers.append(i)
|
||||
print("cur_Outliers.shape:", len(cur_Outliers))
|
||||
print("cur_Outliers:", cur_Outliers)
|
||||
print("Outliers.shape:", len(toatl_Outliers))
|
||||
print("Outliers:", toatl_Outliers)
|
||||
unhealthy_outlier = []
|
||||
for s in toatl_Outliers:
|
||||
if s >= healthy_date:
|
||||
unhealthy_outlier.append(s)
|
||||
print("unhealthy_outlier.shape:", len(unhealthy_outlier))
|
||||
print("unhealthy_outlier.shape:", len(unhealthy_outlier)/(unhealthy_date-healthy_date))
|
||||
print("unhealthy_outlier:", unhealthy_outlier)
|
||||
return sorted(unhealthy_outlier)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
|
||||
total_train_data = total_data[:unhealthy_date, :]
|
||||
# unhealthy_outlier = iqr_outliers_np_all(total_train_data)
|
||||
unhealthy_outlier = sigma_outliers_np_all(total_train_data)
|
||||
# #### TODO 第一步训练
|
||||
# # 单次测试
|
||||
# model = DCConv_Model()
|
||||
#
|
||||
# checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
# filepath=save_name,
|
||||
# monitor='val_loss',
|
||||
# verbose=2,
|
||||
# save_best_only=True,
|
||||
# mode='min')
|
||||
# lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
#
|
||||
# model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
# model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
# model.summary()
|
||||
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
#
|
||||
# # history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
# # train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
# # batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
# # callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
#
|
||||
# ## TODO testing
|
||||
# # # test_data, test_label = get_training_data(total_data[:healthy_date, :])
|
||||
# # model = tf.keras.models.load_model(save_name)
|
||||
# # # mse, mean, max = get_MSE(test_data, test_label, new_model=newModel)
|
||||
# #
|
||||
# # start = time.time()
|
||||
# # # 中间写上代码块
|
||||
# #
|
||||
# # model.predict(train_data_healthy, batch_size=32)
|
||||
# # end = time.time()
|
||||
# # print("data_size:", train_data_healthy.shape)
|
||||
# # print('Running time: %s Seconds' % (end - start))
|
||||
# #
|
||||
# healthy_size, _, _ = train_data_healthy.shape
|
||||
# unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
# all_data, _, _ = get_training_data_overlapping(
|
||||
# total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
#
|
||||
# newModel = tf.keras.models.load_model(save_name)
|
||||
# # 单次测试
|
||||
# # getResult(newModel,
|
||||
# # healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200,
|
||||
# # :],
|
||||
# # healthy_label=train_label1_healthy[
|
||||
# # healthy_size - 2 * unhealthy_size:healthy_size - 2 * unhealthy_size + 200, :],
|
||||
# # unhealthy_data=train_data_unhealthy[:200, :], unhealthy_label=train_label1_unhealthy[:200, :],isSave=True)
|
||||
# getResult(newModel, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy,isSave=False)
|
||||
# # mse, mean, max = get_MSE(train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# # train_label1_healthy[healthy_size - 2 * unhealthy_size:, :], new_model=newModel)
|
||||
pass
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 14:23
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from condition_monitoring.data_deal import loadData_daban as loadData
|
||||
# from model.Joint_Monitoring.Joint_Monitoring_banda import Joint_Monitoring
|
||||
|
||||
# from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
from condition_monitoring.返修.complete.Transformer import Transformer
|
||||
from model.SelfAttention.SelfAttention import Block
|
||||
from keras.callbacks import EarlyStopping
|
||||
import time
|
||||
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
feature_num = 10
|
||||
batch_size = 32
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "transformer"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "../model/joint/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "../model/joint_two/{0}_timestamp{1}_feature{2}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = r"./compare/mse/JM_banda/{0}_result.csv".format(model_name)
|
||||
'''文件名'''
|
||||
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
|
||||
|
||||
'''
|
||||
文件说明:jb4q_8_delete_total_zero.csv是删除了只删除了全是0的列的文件
|
||||
文件从0:96748行均是正常值(2019/12.30 00:00:00 - 2020/3/11 05:58:00)
|
||||
从96748:107116行均是异常值(2020/3/11 05:58:01 - 2021/3/18 11:04:00)
|
||||
'''
|
||||
'''文件参数'''
|
||||
# 最后正常的时间点
|
||||
healthy_date = 96748
|
||||
# 最后异常的时间点
|
||||
unhealthy_date = 107116
|
||||
# 异常容忍程度
|
||||
unhealthy_patience = 5
|
||||
|
||||
|
||||
def remove(data, time_stamp=time_stamp):
|
||||
rows, cols = data.shape
|
||||
print("remove_data.shape:", data.shape)
|
||||
num = int(rows / time_stamp)
|
||||
|
||||
return data[:num * time_stamp, :]
|
||||
pass
|
||||
|
||||
|
||||
# 不重叠采样
|
||||
def get_training_data(data, time_stamp: int = time_stamp):
|
||||
removed_data = remove(data=data)
|
||||
rows, cols = removed_data.shape
|
||||
print("removed_data.shape:", data.shape)
|
||||
print("removed_data:", removed_data)
|
||||
train_data = np.reshape(removed_data, [-1, time_stamp, cols])
|
||||
print("train_data:", train_data)
|
||||
batchs, time_stamp, cols = train_data.shape
|
||||
|
||||
for i in range(1, batchs):
|
||||
each_label = np.expand_dims(train_data[i, 0, :], axis=0)
|
||||
if i == 1:
|
||||
train_label = each_label
|
||||
else:
|
||||
train_label = np.concatenate([train_label, each_label], axis=0)
|
||||
|
||||
print("train_data.shape:", train_data.shape)
|
||||
print("train_label.shape", train_label.shape)
|
||||
return train_data[:-1, :], train_label
|
||||
|
||||
|
||||
# 重叠采样
|
||||
def get_training_data_overlapping(data, time_stamp: int = time_stamp, is_Healthy: bool = True):
|
||||
rows, cols = data.shape
|
||||
train_data = np.empty(shape=[rows - time_stamp - 1, time_stamp, cols])
|
||||
train_label = np.empty(shape=[rows - time_stamp - 1, cols])
|
||||
for i in range(rows):
|
||||
if i + time_stamp >= rows:
|
||||
break
|
||||
if i + time_stamp < rows - 1:
|
||||
train_data[i] = data[i:i + time_stamp]
|
||||
train_label[i] = data[i + time_stamp]
|
||||
|
||||
print("重叠采样以后:")
|
||||
print("data:", train_data) # (300334,120,10)
|
||||
print("label:", train_label) # (300334,10)
|
||||
|
||||
if is_Healthy:
|
||||
train_label2 = np.ones(shape=[train_label.shape[0]])
|
||||
else:
|
||||
train_label2 = np.zeros(shape=[train_label.shape[0]])
|
||||
|
||||
print("label2:", train_label2)
|
||||
|
||||
return train_data, train_label, train_label2
|
||||
|
||||
|
||||
# 归一化
|
||||
def normalization(data):
|
||||
rows, cols = data.shape
|
||||
print("归一化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 归一化
|
||||
max = np.max(data, axis=0)
|
||||
max = np.broadcast_to(max, [rows, cols])
|
||||
min = np.min(data, axis=0)
|
||||
min = np.broadcast_to(min, [rows, cols])
|
||||
|
||||
data = (data - min) / (max - min)
|
||||
print("归一化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# 正则化
|
||||
def Regularization(data):
|
||||
rows, cols = data.shape
|
||||
print("正则化之前:", data)
|
||||
print(data.shape)
|
||||
print("======================")
|
||||
|
||||
# 正则化
|
||||
mean = np.mean(data, axis=0)
|
||||
mean = np.broadcast_to(mean, shape=[rows, cols])
|
||||
dst = np.sqrt(np.var(data, axis=0))
|
||||
dst = np.broadcast_to(dst, shape=[rows, cols])
|
||||
data = (data - mean) / dst
|
||||
print("正则化之后:", data)
|
||||
print(data.shape)
|
||||
|
||||
return data
|
||||
pass
|
||||
|
||||
|
||||
def EWMA(data, K=K, namuda=namuda):
|
||||
# t是啥暂时未知
|
||||
t = 0
|
||||
mid = np.mean(data, axis=0)
|
||||
standard = np.sqrt(np.var(data, axis=0))
|
||||
UCL = mid + K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
LCL = mid - K * standard * np.sqrt(namuda / (2 - namuda) * (1 - (1 - namuda) ** 2 * t))
|
||||
return mid, UCL, LCL
|
||||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
|
||||
dims, = mse.shape
|
||||
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
|
||||
# plt.plot(max)
|
||||
# plt.plot(mse)
|
||||
# plt.plot(mean)
|
||||
# # plt.plot(min)
|
||||
# plt.show()
|
||||
#
|
||||
#
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def test(step_one_model, step_two_model, test_data, test_label1, test_label2):
|
||||
history_loss = []
|
||||
history_val_loss = []
|
||||
|
||||
val_loss, val_accuracy = step_two_model.get_val_loss(val_data=test_data, val_label1=test_label1,
|
||||
val_label2=test_label2,
|
||||
is_first_time=False, step_one_model=step_one_model)
|
||||
|
||||
history_val_loss.append(val_loss)
|
||||
print("val_accuracy:", val_accuracy)
|
||||
print("val_loss:", val_loss)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
|
||||
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
|
||||
total_data[:healthy_date, :], is_Healthy=True)
|
||||
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
|
||||
total_data[healthy_date - time_stamp + unhealthy_patience:unhealthy_date, :],
|
||||
is_Healthy=False)
|
||||
#### TODO 第一步训练
|
||||
|
||||
####### TODO 训练
|
||||
model = Transformer(embed_dim=10, depth=5, num_heads=1, num_classes=10,representation_size=128)
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.mse)
|
||||
model.build(input_shape=(batch_size, time_stamp, feature_num))
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=3, mode='min', verbose=1)
|
||||
|
||||
history = model.fit(train_data_healthy[:train_data_healthy.shape[0] // 7, :, :],
|
||||
train_label1_healthy[:train_label1_healthy.shape[0] // 7, ], epochs=EPOCH,
|
||||
batch_size=batch_size * 10, validation_split=0.2, shuffle=True, verbose=1,
|
||||
callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
#
|
||||
#
|
||||
#
|
||||
# #### TODO 测试
|
||||
|
||||
#
|
||||
start = time.time()
|
||||
# 中间写上代码块
|
||||
|
||||
model.predict(train_data_healthy, batch_size=32)
|
||||
end = time.time()
|
||||
print("data_size:", train_data_healthy.shape)
|
||||
print('Running time: %s Seconds' % (end - start))
|
||||
|
||||
trained_model = tf.keras.models.load_model(save_name, custom_objects={'Block': Block})
|
||||
#
|
||||
#
|
||||
#
|
||||
# # 使用已知的点进行预测
|
||||
#
|
||||
# pass
|
||||
|
|
@ -97,6 +97,16 @@ class DynamicPooling(layers.Layer):
|
|||
self.pool_size = pool_size
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'pool_size': self.pool_size
|
||||
}
|
||||
)
|
||||
base_config = super(DynamicPooling, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def build(self, input_shape):
|
||||
if len(input_shape) != 3:
|
||||
raise ValueError('Inputs to `DynamicChannelAttention` should have rank 3. '
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention, DynamicPooling
|
||||
from condition_monitoring.data_deal import loadData
|
||||
from model.LossFunction.smooth_L1_Loss import SmoothL1Loss
|
||||
import math
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ import tensorflow.keras as keras
|
|||
from tensorflow.keras import *
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention, DynamicPooling
|
||||
from condition_monitoring.data_deal import loadData
|
||||
from model.ChannelAttention.Dynamic_channelAttention import DynamicChannelAttention, DynamicPooling
|
||||
from model.LossFunction.smooth_L1_Loss import SmoothL1Loss
|
||||
import math
|
||||
|
||||
MSE_loss1_list=[]
|
||||
MSE_loss2_list=[]
|
||||
MSE_loss3_list=[]
|
||||
|
||||
class Joint_Monitoring(keras.Model):
|
||||
|
||||
|
|
@ -28,12 +28,13 @@ class Joint_Monitoring(keras.Model):
|
|||
super(Joint_Monitoring, self).__init__()
|
||||
# step one
|
||||
self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5)
|
||||
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||
# self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME')
|
||||
self.upsample1 = tf.keras.layers.UpSampling1D(size=2)
|
||||
|
||||
self.DACU2 = DynamicChannelAttention()
|
||||
self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3)
|
||||
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME')
|
||||
self.upsample2 = tf.keras.layers.UpSampling1D(size=2)
|
||||
|
||||
self.DACU3 = DynamicChannelAttention()
|
||||
|
|
@ -255,6 +256,11 @@ class Joint_Monitoring(keras.Model):
|
|||
print("MSE_loss1:", MSE_loss1.numpy())
|
||||
print("MSE_loss2:", MSE_loss2.numpy())
|
||||
print("MSE_loss3:", MSE_loss3.numpy())
|
||||
|
||||
# MSE_loss1_list.append(MSE_loss1.numpy())
|
||||
# MSE_loss2_list.append(MSE_loss2.numpy())
|
||||
# MSE_loss3_list.append(MSE_loss3.numpy())
|
||||
|
||||
loss = MSE_loss1 + MSE_loss2 + MSE_loss3
|
||||
Accuracy_num = 0
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
||||
from model.Dynamic_channelAttention.Dynamic_channelAttention import DynamicChannelAttention, DynamicPooling
|
||||
from model.ChannelAttention.Dynamic_channelAttention import DynamicChannelAttention, DynamicPooling
|
||||
from condition_monitoring.data_deal import loadData
|
||||
from model.LossFunction.smooth_L1_Loss import SmoothL1Loss
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,417 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 10:53
|
||||
@Usage : 即插即用的self_attention模块
|
||||
@Desc :
|
||||
'''
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import Model, layers, initializers
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PatchEmbed(layers.Layer):
|
||||
"""
|
||||
2D Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, embed_dim=768):
|
||||
super(PatchEmbed, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_size=patch_size
|
||||
self.img_size = (img_size, img_size)
|
||||
# 将图片划分成 img_size // patch_size行img_size // patch_size列的网格
|
||||
self.grid_size = (img_size // patch_size, img_size // patch_size)
|
||||
# 行列相乘,可以得到patch的数目,即14*14=196,到时候flatten层的输入即是196维
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
# con2d是16*16,然后步距也是16,即可将图片划分
|
||||
self.proj = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
|
||||
strides=patch_size, padding='SAME',
|
||||
kernel_initializer=initializers.LecunNormal(),
|
||||
bias_initializer=initializers.Zeros())
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'img_size': self.img_size[0],
|
||||
'patch_size': self.patch_size,
|
||||
'embed_dim': self.embed_dim
|
||||
}
|
||||
)
|
||||
base_config = super(PatchEmbed, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
# 正向传播过程
|
||||
def call(self, inputs, **kwargs):
|
||||
# B-Batch的维度,H-高度, W-宽度, C-channel的维度
|
||||
B, H, W, C = inputs.shape
|
||||
# 设置的高和宽要和上面的保持一致,如果不一致会报错
|
||||
# 这里和CNN模型不一样,CNN模型会有一个全局池化层来对图片大小进行一个池化,所以可以不限制图片的具体大小
|
||||
# 而在transformer中是需要加上一个position embedding的,而这个层就需要知道图片的具体维度来设置position embedding的大小
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
# 将input传入设置好的conv2d,对其进行输出
|
||||
x = self.proj(inputs)
|
||||
# [B, H, W, C] -> [B, H*W, C]
|
||||
# reshape成可以进入flatten的形式
|
||||
x = tf.reshape(x, [B, self.num_patches, self.embed_dim])
|
||||
return x
|
||||
|
||||
|
||||
class ConcatClassTokenAddPosEmbed(layers.Layer):
|
||||
def __init__(self, embed_dim=768, num_patches=196, name=None):
|
||||
super(ConcatClassTokenAddPosEmbed, self).__init__(name=name)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_patches = num_patches
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'num_patches': self.num_patches,
|
||||
'embed_dim': self.embed_dim
|
||||
}
|
||||
)
|
||||
base_config = super(ConcatClassTokenAddPosEmbed, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def build(self, input_shape):
|
||||
# 创建两个可训练的参数weight,对应于cls_token,pos_embed的weight
|
||||
# shape第一维是batch维度,后面分别是1*768和197*768,trainable表示是可训练的参数
|
||||
self.cls_token = self.add_weight(name="cls",
|
||||
shape=[1, 1, self.embed_dim],
|
||||
initializer=initializers.Zeros(),
|
||||
trainable=True,
|
||||
dtype=tf.float32)
|
||||
self.pos_embed = self.add_weight(name="pos_embed",
|
||||
shape=[1, self.num_patches + 1, self.embed_dim],
|
||||
initializer=initializers.RandomNormal(stddev=0.02),
|
||||
trainable=True,
|
||||
dtype=tf.float32)
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
# 需要获取一下batch的,因为输入图片的时候不是一张图片,而是将图片打包成一个个batch同时输入进去
|
||||
batch_size, _, _ = inputs.shape
|
||||
|
||||
# [1, 1, 768] -> [B, 1, 768]
|
||||
# 把cls_token复制B份进行拼接,broadcast_to方法,将cls_token进行一下广播,在batch维度就可以变成batch维度
|
||||
cls_token = tf.broadcast_to(self.cls_token, shape=[batch_size, 1, self.embed_dim])
|
||||
# 与input进行拼接
|
||||
x = tf.concat([cls_token, inputs], axis=1) # [B, 197, 768]
|
||||
# 加上位置编码
|
||||
x = x + self.pos_embed
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttention(layers.Layer):
|
||||
# 定义两个权重初始化方法,方便后续调用
|
||||
k_ini = initializers.GlorotUniform()
|
||||
b_ini = initializers.Zeros()
|
||||
|
||||
# dim是前面的embed的dimension,num_heads多头注意力的头数
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop_ratio=0.,
|
||||
proj_drop_ratio=0.,
|
||||
name=None):
|
||||
super(SelfAttention, self).__init__(name=name)
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qk_scale = qk_scale
|
||||
self.attn_drop_ratio = attn_drop_ratio
|
||||
self.proj_drop_ratio = proj_drop_ratio
|
||||
|
||||
|
||||
|
||||
|
||||
# 每一个head的dimension=输入的dimension/num_heads
|
||||
head_dim = dim // num_heads
|
||||
# 做softmax时会除一个sqrt(dk),这里的scale就是那个sqrt(dk)——缩放因子
|
||||
# 如果传入了qk_scale就用传入的,如果没传入就用sqrt(head_dim)
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
# 在有的做法中,生成QKV三个矩阵时会用三个全连接层
|
||||
# 这里用3*dim维的大全连接层,可以达到一样的效果
|
||||
self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
|
||||
kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
|
||||
self.attn_drop = layers.Dropout(attn_drop_ratio)
|
||||
# 这里的全连接层是生成Wo矩阵,将得到的b进一步拼接
|
||||
# 由于这里multi-head self-attention模块的输入输出维度是一样的,所以这里的节点个数是dim
|
||||
self.proj = layers.Dense(dim, name="out",
|
||||
kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
|
||||
self.proj_drop = layers.Dropout(proj_drop_ratio)
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'dim': self.dim,
|
||||
'num_heads': self.num_heads,
|
||||
'qkv_bias': self.qkv_bias,
|
||||
'qk_scale': self.qk_scale,
|
||||
'attn_drop_ratio': self.attn_drop_ratio,
|
||||
'proj_drop_ratio': self.proj_drop_ratio,
|
||||
}
|
||||
)
|
||||
base_config = super(SelfAttention, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
# 由于进入self-attention的时候就已经将展平过了,所以这里的inputs.shape只有三个维度
|
||||
# [batch_size, num_patches + 1, total_embed_dim]
|
||||
B, N, C = inputs.shape
|
||||
# B, N, C = tf.shape(inputs)[0],tf.shape(inputs)[1],tf.shape(inputs)[2]
|
||||
|
||||
# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
|
||||
qkv = self.qkv(inputs)
|
||||
# 分成三份,分别对应qkv,C // self.num_heads得到每一个head的维度
|
||||
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
|
||||
# qkv = tf.reshape(qkv, shape=[B, N, 3, self.num_heads, C // self.num_heads])
|
||||
qkv = tf.reshape(qkv, shape=[-1, N, 3, self.num_heads, C // self.num_heads])
|
||||
# qkv = tf.keras.layers.Reshape(target_shape=[B, N, 3, self.num_heads, C // self.num_heads])(qkv)
|
||||
# 用transpose方法,来调整一下维度的顺序,[2, 0, 3, 1, 4]表示调换之后的顺序
|
||||
# transpose: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
||||
qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
|
||||
# [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
|
||||
# 这里的矩阵相乘实际上指的是矩阵的最后两个维度相乘,而b的转置(transpose_b)
|
||||
# 实际上是[batch_size, num_heads, embed_dim_per_head, num_patches + 1]
|
||||
# multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
|
||||
attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale
|
||||
attn = tf.nn.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn, training=training)
|
||||
|
||||
# 与v相乘得到b
|
||||
# multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
||||
x = tf.matmul(attn, v)
|
||||
# 再用transpose调换一下顺序
|
||||
# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
|
||||
x = tf.transpose(x, [0, 2, 1, 3])
|
||||
# reshape: -> [batch_size, num_patches + 1, total_embed_dim]
|
||||
# x = tf.reshape(x, [B, N, C])
|
||||
x = tf.reshape(x, [-1, N, C])
|
||||
|
||||
# 与Wo相乘,进一步融合得到输出
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x, training=training)
|
||||
return x
|
||||
|
||||
|
||||
class MLP(layers.Layer):
|
||||
"""
|
||||
MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||
"""
|
||||
# 定义两个权重初始化方法
|
||||
k_ini = initializers.GlorotUniform()
|
||||
b_ini = initializers.RandomNormal(stddev=1e-6)
|
||||
|
||||
# in_deatures表示输入MLP模块对应的dimension,mlp_ratio=4.0表示翻四倍
|
||||
def __init__(self, in_features, mlp_ratio=4.0, drop_rate=0., name=None):
|
||||
super(MLP, self).__init__(name=name)
|
||||
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="Dense_0",
|
||||
kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
|
||||
self.act = layers.Activation("relu")
|
||||
self.fc2 = layers.Dense(in_features, name="Dense_1",
|
||||
kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
|
||||
self.drop = layers.Dropout(drop_rate)
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'in_features': self.in_features,
|
||||
'mlp_ratio': self.mlp_ratio,
|
||||
'drop_rate': self.drop_rate,
|
||||
}
|
||||
)
|
||||
base_config = super(MLP, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
x = self.fc1(inputs)
|
||||
x = self.act(x)
|
||||
x = self.drop(x, training=training)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x, training=training)
|
||||
return x
|
||||
|
||||
|
||||
class Block(layers.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_ratio=0.,
|
||||
attn_drop_ratio=0.,
|
||||
drop_path_ratio=0.,
|
||||
name=None):
|
||||
super(Block, self).__init__(name=name)
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qk_scale = qk_scale
|
||||
self.drop_ratio = drop_ratio
|
||||
self.attn_drop_ratio = attn_drop_ratio
|
||||
self.drop_path_ratio = drop_path_ratio
|
||||
|
||||
|
||||
|
||||
|
||||
# LayerNormalization来进行正则化
|
||||
self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_0")
|
||||
# 调用Attention类,来实现MultiHeadAttention
|
||||
self.attn = SelfAttention(dim, num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio,
|
||||
name="MultiHeadAttention")
|
||||
# 这里所一个判断,如果drop_path_ratio > 0.就构建droppath方法
|
||||
# 如果drop_path_ratio < 0,就用linear,即输入是什么输出就是什么,不作操作
|
||||
# droppath方式的实现Dropout+noise_shape=(None, 1, 1)就可以实现droppath方法
|
||||
# 第一个None表示的是batch维度,第一个1表示的是num_patches+1,最后一个1表示的是embed_dimension
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = layers.Dropout(rate=drop_path_ratio, noise_shape=(None, 1, 1)) if drop_path_ratio > 0. \
|
||||
else layers.Activation("linear")
|
||||
self.norm2 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_1")
|
||||
self.mlp = MLP(dim, drop_rate=drop_ratio, name="MlpBlock")
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'dim': self.dim,
|
||||
'num_heads': self.num_heads,
|
||||
'qkv_bias': self.qkv_bias,
|
||||
'qk_scale': self.qk_scale,
|
||||
'drop_ratio': self.drop_ratio,
|
||||
'attn_drop_ratio': self.attn_drop_ratio,
|
||||
'drop_path_ratio': self.drop_path_ratio,
|
||||
}
|
||||
)
|
||||
base_config = super(Block, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
# 对应于图中第一个加号以前的部分
|
||||
x = inputs + self.drop_path(self.attn(self.norm1(inputs)), training=training)
|
||||
# 对应于图中第一个加号以后的部分
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)), training=training)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(Model):
|
||||
|
||||
# depth表示的是重复encoder block的次数,num_heads表示的是在multi-head self-attention中head的个数
|
||||
# MLP block中有一个Pre_logist,这里指的是,当在较大的数据集上学习的时候Pre_logist就表示一个全连接层加上一个tanh激活函数
|
||||
# 当在较小的数据集上学习的时候,Pre_logist是没有的,而这里的representation_size表示的就是Pre_logist中全连接层的节点个数
|
||||
# num_classes表示分类的类数
|
||||
def __init__(self, img_size=224, patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, qkv_bias=True, qk_scale=None,
|
||||
drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
|
||||
representation_size=None, num_classes=1000, name="ViT-B/16"):
|
||||
super(VisionTransformer, self).__init__(name=name)
|
||||
|
||||
self.img_size=img_size
|
||||
self.patch_size=patch_size
|
||||
self.num_classes = num_classes
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.depth = depth
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qk_scale = qk_scale
|
||||
self.drop_ratio = drop_ratio
|
||||
self.attn_drop_ratio = attn_drop_ratio
|
||||
self.drop_path_ratio = drop_path_ratio
|
||||
self.representation_size = representation_size
|
||||
|
||||
# 这里实例化了PatchEmbed类
|
||||
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
|
||||
# 将PatchEmbed类中的num_patches取出来赋值给num_patches
|
||||
num_patches = self.patch_embed.num_patches
|
||||
# 这里实例化了ConcatClassTokenAddPosEmbed类
|
||||
self.cls_token_pos_embed = ConcatClassTokenAddPosEmbed(embed_dim=embed_dim,
|
||||
num_patches=num_patches,
|
||||
name="cls_pos")
|
||||
|
||||
self.pos_drop = layers.Dropout(drop_ratio)
|
||||
|
||||
dpr = np.linspace(0., drop_path_ratio, depth) # stochastic depth decay rule
|
||||
# 用一个for循环重复Block模块
|
||||
# 在用droppath时的drop_path_ratio是由0慢慢递增到我们所指定的drop_path_ratio的
|
||||
# 所以我们在构建Block时,这里的drop_path_ratio时变化的,所以用 np.linspace方法创建一个等差数列来初始化drop_path_ratio
|
||||
self.blocks = [Block(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
|
||||
drop_path_ratio=dpr[i], name="encoderblock_{}".format(i))
|
||||
for i in range(depth)]
|
||||
|
||||
self.norm = layers.LayerNormalization(epsilon=1e-6, name="encoder_norm")
|
||||
|
||||
# 接下来,如果传入了representation_size,就构建一个全连接层,激活函数为tanh
|
||||
# 如果没有传入的话,就不做任何操作
|
||||
if representation_size:
|
||||
self.has_logits = True
|
||||
self.pre_logits = layers.Dense(representation_size, activation="tanh", name="pre_logits")
|
||||
else:
|
||||
self.has_logits = False
|
||||
self.pre_logits = layers.Activation("linear")
|
||||
|
||||
# 定义最后一个全连接层,节点个数就是我们的分类个数num_classes
|
||||
self.head = layers.Dense(num_classes, name="head", kernel_initializer=initializers.Zeros())
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'img_size': self.img_size,
|
||||
'patch_size': self.patch_size,
|
||||
'embed_dim': self.embed_dim,
|
||||
'depth': self.depth,
|
||||
'num_heads': self.num_heads,
|
||||
'qkv_bias': self.qkv_bias,
|
||||
'qk_scale': self.qk_scale,
|
||||
'drop_ratio': self.drop_ratio,
|
||||
'attn_drop_ratio': self.attn_drop_ratio,
|
||||
'drop_path_ratio': self.drop_path_ratio,
|
||||
'representation_size': self.representation_size,
|
||||
'num_classes': self.num_classes
|
||||
}
|
||||
)
|
||||
base_config = super(VisionTransformer, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
# [B, H, W, C] -> [B, num_patches, embed_dim]
|
||||
x = self.patch_embed(inputs) # [B, 196, 768]
|
||||
x = self.cls_token_pos_embed(x) # [B, 197, 768]
|
||||
x = self.pos_drop(x, training=training)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, training=training)
|
||||
|
||||
x = self.norm(x)
|
||||
# 这里是提取class_toke的输出,然后用切片的方式,而刚刚是将class_toke拼接在最前面的
|
||||
# 所以这里用切片的方式,去取class_toke的输出,并将它传递给pre_logits
|
||||
x = self.pre_logits(x[:, 0])
|
||||
# 最后传递给head
|
||||
x = self.head(x)
|
||||
# 为什么只用class_toke对应的输出,而不用每一个patches对应的输出呢?
|
||||
# 可以参考原文bird 网络
|
||||
|
||||
return x
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/10/23 10:52
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
Loading…
Reference in New Issue