用户画像example更新
This commit is contained in:
parent
ecbc11a572
commit
e89c402af8
|
|
@ -9,3 +9,6 @@ mysql.password=123456
|
||||||
|
|
||||||
# clickhouse配置 (clickhouse configuration)
|
# clickhouse配置 (clickhouse configuration)
|
||||||
clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
|
clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
|
||||||
|
|
||||||
|
# 模型保存位置 (model save location)
|
||||||
|
save-model.path=hdfs://Ding202:8020/user_profile/train_model/busi_gender
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ package com.atguigu.userprofile.ml.pipline
|
||||||
|
|
||||||
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
|
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
|
||||||
import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
|
import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
|
||||||
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
|
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorAssembler, VectorIndexer}
|
||||||
|
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.DataFrame
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -136,6 +138,10 @@ class MyPipeLine {
|
||||||
.setFeaturesCol("feature_index")
|
.setFeaturesCol("feature_index")
|
||||||
.setPredictionCol("prediction_col")
|
.setPredictionCol("prediction_col")
|
||||||
.setImpurity("gini") //使用信息熵还是gini
|
.setImpurity("gini") //使用信息熵还是gini
|
||||||
|
.setMinInfoGain(minInfoGain)
|
||||||
|
.setMaxBins(maxBins)
|
||||||
|
.setMaxDepth(maxDepth)
|
||||||
|
.setMinInstancesPerNode(minInstancesPerNode)
|
||||||
|
|
||||||
classifier
|
classifier
|
||||||
}
|
}
|
||||||
|
|
@ -175,6 +181,60 @@ class MyPipeLine {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//把预测列 的矢量值转换为原值
|
||||||
|
def convertOrigin(predictedDataFrame:DataFrame): DataFrame ={
|
||||||
|
|
||||||
|
//找一号助理要矢量值与原值之间的对应关系
|
||||||
|
val transformer: Transformer = pipelineModel.stages(0)
|
||||||
|
val stringIndexerModel: StringIndexerModel = transformer.asInstanceOf[StringIndexerModel]
|
||||||
|
|
||||||
|
//定义一个转换器
|
||||||
|
val indexToString = new IndexToString()
|
||||||
|
//用转换器转换数据
|
||||||
|
indexToString.setInputCol("prediction_col").setOutputCol("prediction_origin").setLabels(stringIndexerModel.labels)
|
||||||
|
|
||||||
|
val convertedDataFrame = indexToString.transform(predictedDataFrame)
|
||||||
|
convertedDataFrame
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
//打印评估报告 // 总准确率 // 各个选项的 召回率 和精确率
|
||||||
|
def printEvaluateReport(predictedDataFrame:DataFrame):Unit ={
|
||||||
|
//获取预测列
|
||||||
|
val predictAndLabelRDD: RDD[(Double, Double)] = predictedDataFrame.rdd.map {
|
||||||
|
row => {
|
||||||
|
val predictValue: Double = row.getAs[Double]("prediction_col")
|
||||||
|
val labelValue: Double = row.getAs[Double]("label_index")
|
||||||
|
(predictValue, labelValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val metrics = new MulticlassMetrics(predictAndLabelRDD)
|
||||||
|
|
||||||
|
println("总准确率:"+metrics.accuracy)
|
||||||
|
|
||||||
|
metrics.labels.foreach(
|
||||||
|
label =>{
|
||||||
|
println(s"矢量值为:$label 的精确率: ${metrics.precision(label)} ")
|
||||||
|
println(s"矢量值为:$label 的召回率: ${metrics.recall(label)} ")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
//把生成的模型存储到指定的位置 ->hdfs
|
||||||
|
def saveModel(path:String) :Unit ={
|
||||||
|
pipelineModel.write.overwrite().save(path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
package com.atguigu.userprofile.ml.train
|
||||||
|
|
||||||
|
import java.util.Properties
|
||||||
|
|
||||||
|
import com.atguigu.userprofile.common.utils.MyPropertiesUtil
|
||||||
|
import com.atguigu.userprofile.ml.pipline.MyPipeLine
|
||||||
|
import org.apache.spark.SparkConf
|
||||||
|
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 3、 取完整的特征+label标签 SQL
|
||||||
|
* 4、定义流水线
|
||||||
|
* 5、 获得数据,把数据投入流水线 训练
|
||||||
|
* 6、 模拟预测 评估 (评分低:特征? 算法? 参数?)
|
||||||
|
* 7、 如果模型达到要求 , 把模型保存起来 hdfs
|
||||||
|
*/
|
||||||
|
object BusiGenderTrain {
|
||||||
|
|
||||||
|
def main(args: Array[String]): Unit = {
|
||||||
|
|
||||||
|
//0 创建spark环境
|
||||||
|
val sparkConf: SparkConf = new SparkConf().setAppName("stud_gender_train_app").setMaster("local[*]")
|
||||||
|
val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
|
||||||
|
|
||||||
|
//3 取完整的特征+label标签 SQL
|
||||||
|
println("取完整的特征 + label标签 SQL")
|
||||||
|
|
||||||
|
val selectSQL=
|
||||||
|
s"""
|
||||||
|
|with user_c1
|
||||||
|
| as (
|
||||||
|
| select user_id,category1_id,during_time
|
||||||
|
| from dwd_page_log pl join dim_sku_info si
|
||||||
|
| on pl.page_item = si.id
|
||||||
|
| where pl.page_id = 'good_detail' and pl.page_item_type = 'sku_id' and pl.dt = '2022-05-20' and si.dt = '2022-05-20'
|
||||||
|
| ),
|
||||||
|
| user_label
|
||||||
|
| as (
|
||||||
|
| select id,gender from dim_user_info where dt='9999-99-99' and gender<>''
|
||||||
|
| )
|
||||||
|
|select user_id,c1_rk1,c1_rk2,c1_rk3,male_dur,female_dur,user_label.gender
|
||||||
|
|from
|
||||||
|
| (
|
||||||
|
| select user_id,sum(if(rk = 1,category1_id,0)) c1_rk1,sum(if(rk = 2,category1_id,0)) c1_rk2,sum(if(rk = 3,category1_id,0)) c1_rk3,
|
||||||
|
| sum(if(category1_id in (2,3,6) ,during_time,0)) male_dur,sum(if(category1_id in (11,15,8) ,during_time,0)) female_dur
|
||||||
|
| from
|
||||||
|
| (select user_id ,category1_id,ct,during_time,
|
||||||
|
| row_number() over( partition by user_id order by ct desc) rk
|
||||||
|
| from
|
||||||
|
| (
|
||||||
|
| select user_id ,category1_id,count(*) ct,sum(during_time) during_time
|
||||||
|
| from user_c1
|
||||||
|
| group by user_id,category1_id
|
||||||
|
| order by user_id,category1_id
|
||||||
|
| ) user_c1_ct
|
||||||
|
| order by user_id,category1_id) user_rk
|
||||||
|
| group by user_id
|
||||||
|
| ) user_feature join user_label on user_feature.user_id=user_label.id
|
||||||
|
|order by user_id
|
||||||
|
|""".stripMargin
|
||||||
|
|
||||||
|
sparkSession.sql(s"use gmall")
|
||||||
|
val dataFrame: DataFrame = sparkSession.sql(selectSQL)
|
||||||
|
|
||||||
|
println("数据拆分")
|
||||||
|
|
||||||
|
//把数据拆分
|
||||||
|
val Array(trainDF,testDF)=dataFrame.randomSplit(Array(0.8,0.2))
|
||||||
|
|
||||||
|
println("定义流水线")
|
||||||
|
|
||||||
|
//4 定义流水线
|
||||||
|
val myPipeLine: MyPipeLine = new MyPipeLine()
|
||||||
|
.setLabelColName("gender")
|
||||||
|
.setFeatureColNames(Array(
|
||||||
|
"c1_rk1","c1_rk2","c1_rk3","male_dur","female_dur"
|
||||||
|
))
|
||||||
|
.setMaxCategories(20) //鉴别谁是连续,谁是离散
|
||||||
|
.setMaxDepth(6)
|
||||||
|
.setMinInfoGain(0.03)
|
||||||
|
.setMaxBins(32)
|
||||||
|
.setMinInstancesPerNode(3)
|
||||||
|
.init()
|
||||||
|
|
||||||
|
println("进行训练")
|
||||||
|
|
||||||
|
|
||||||
|
//5 获得数据,把数据投入流水线 训练
|
||||||
|
myPipeLine.train(trainDF)
|
||||||
|
|
||||||
|
//打印决策树和特征权重
|
||||||
|
myPipeLine.printDecisionTree()
|
||||||
|
myPipeLine.printFeatureWeights()
|
||||||
|
|
||||||
|
|
||||||
|
println("进行预测")
|
||||||
|
//6 模拟预测 评估 (评分低:特征? 算法? 参数?)
|
||||||
|
val predictedDataFrame = myPipeLine.predict(testDF)
|
||||||
|
//展示预测结果
|
||||||
|
predictedDataFrame.show(100,false)
|
||||||
|
//打印评估报告
|
||||||
|
myPipeLine.printEvaluateReport(predictedDataFrame)
|
||||||
|
|
||||||
|
|
||||||
|
//7 存储模型
|
||||||
|
val properties: Properties = MyPropertiesUtil.load("config.properties")
|
||||||
|
val saveModelPath: String = properties.getProperty("save-model.path")
|
||||||
|
myPipeLine.saveModel(saveModelPath)
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -2,7 +2,7 @@ package com.atguigu.userprofile.ml.train
|
||||||
|
|
||||||
import com.atguigu.userprofile.ml.pipline.MyPipeLine
|
import com.atguigu.userprofile.ml.pipline.MyPipeLine
|
||||||
import org.apache.spark.SparkConf
|
import org.apache.spark.SparkConf
|
||||||
import org.apache.spark.sql.SparkSession
|
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||||
|
|
||||||
object StudGenderTrain {
|
object StudGenderTrain {
|
||||||
|
|
||||||
|
|
@ -49,6 +49,10 @@ object StudGenderTrain {
|
||||||
val myPineLine = new MyPipeLine().setLabelColName("gender")
|
val myPineLine = new MyPipeLine().setLabelColName("gender")
|
||||||
.setFeatureColNames(Array("hair","height","skirt","age"))
|
.setFeatureColNames(Array("hair","height","skirt","age"))
|
||||||
.setMaxCategories(5)
|
.setMaxCategories(5)
|
||||||
|
.setMaxDepth(6)
|
||||||
|
.setMinInfoGain(0.1)
|
||||||
|
.setMaxBins(32)
|
||||||
|
.setMinInstancesPerNode(4)
|
||||||
.init()
|
.init()
|
||||||
//4 进行训练
|
//4 进行训练
|
||||||
println("进行训练...")
|
println("进行训练...")
|
||||||
|
|
@ -65,6 +69,18 @@ object StudGenderTrain {
|
||||||
//false表示不切割
|
//false表示不切割
|
||||||
predictedDataFrame.show(100,false)
|
predictedDataFrame.show(100,false)
|
||||||
|
|
||||||
|
//7 把矢量预测结果转换为原始值
|
||||||
|
println("进行转换...")
|
||||||
|
val convertedDataFrame: DataFrame = myPineLine.convertOrigin(predictedDataFrame)
|
||||||
|
convertedDataFrame.show(100,false)
|
||||||
|
|
||||||
|
//8 打印评估报告 // 总准确率 // 各个选项的 召回率 和精确率
|
||||||
|
myPineLine.printEvaluateReport(convertedDataFrame)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue