From e89c402af8df00157422d300ae99ca7110882c08 Mon Sep 17 00:00:00 2001
From: markilue <745518019@qq.com>
Date: Tue, 24 May 2022 18:42:41 +0800
Subject: [PATCH] =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=94=BB=E5=83=8Fexample?=
 =?UTF-8?q?=E6=9B=B4=E6=96=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/main/resources/config.properties      |   3 +
 .../userprofile/ml/pipline/MyPipeLine.scala   |  62 +++++++++-
 .../ml/train/BusiGenderTrain.scala            | 113 ++++++++++++++++++
 .../ml/train/StudGenderTrain.scala            |  18 ++-
 4 files changed, 194 insertions(+), 2 deletions(-)
 create mode 100644 Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/BusiGenderTrain.scala

diff --git a/Big_data_example/user_profile/machine-learning/src/main/resources/config.properties b/Big_data_example/user_profile/machine-learning/src/main/resources/config.properties
index d76795a..265e144 100644
--- a/Big_data_example/user_profile/machine-learning/src/main/resources/config.properties
+++ b/Big_data_example/user_profile/machine-learning/src/main/resources/config.properties
@@ -9,3 +9,6 @@ mysql.password=123456
 
 # clickhouse配置
 clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
+
+# 模型保存位置
+save-model.path=hdfs://Ding202:8020/user_profile/train_model/busi_gender
diff --git a/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/pipline/MyPipeLine.scala b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/pipline/MyPipeLine.scala
index a3a6a67..2179de9 100644
--- a/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/pipline/MyPipeLine.scala
+++ b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/pipline/MyPipeLine.scala
@@ -2,7 +2,9 @@ package com.atguigu.userprofile.ml.pipline
 
 import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, 
DecisionTreeClassifier}
 import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
-import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
+import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorAssembler, VectorIndexer}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 
 /**
@@ -136,6 +138,10 @@ class MyPipeLine {
       .setFeaturesCol("feature_index")
       .setPredictionCol("prediction_col")
       .setImpurity("gini") //浣跨敤淇℃伅鐔佃繕鏄痝ini
+      .setMinInfoGain(minInfoGain)
+      .setMaxBins(maxBins)
+      .setMaxDepth(maxDepth)
+      .setMinInstancesPerNode(minInstancesPerNode)
     classifier
   }
 
@@ -175,6 +181,60 @@ class MyPipeLine {
 
   }
 
+  /**
+   * Converts the numeric values in "prediction_col" back to the original label
+   * strings and writes them to a new "prediction_origin" column.
+   *
+   * @param predictedDataFrame data frame that already contains "prediction_col"
+   * @return the result of the IndexToString transformation of the input
+   * @throws IllegalStateException if stage 0 of the fitted pipeline is not a StringIndexerModel
+   */
+  def convertOrigin(predictedDataFrame:DataFrame): DataFrame ={
+    // Stage 0 is expected to be the fitted label indexer, which holds the index -> label
+    // mapping; match on its type instead of using an unchecked asInstanceOf cast.
+    val labels: Array[String] = pipelineModel.stages.headOption match {
+      case Some(indexerModel: StringIndexerModel) => indexerModel.labels
+      case other =>
+        throw new IllegalStateException(s"Expected a StringIndexerModel as pipeline stage 0 but found: $other")
+    }
+    new IndexToString()
+      .setInputCol("prediction_col")
+      .setOutputCol("prediction_origin")
+      .setLabels(labels)
+      .transform(predictedDataFrame)
+  }
+
+  /**
+   * Prints an evaluation report: overall accuracy, then precision and recall per label index.
+   *
+   * @param predictedDataFrame data frame with Double columns "prediction_col" and "label_index"
+   */
+  def printEvaluateReport(predictedDataFrame:DataFrame):Unit ={
+    // Only the two columns needed for scoring are turned into (prediction, label) pairs.
+    val predictAndLabelRDD: RDD[(Double, Double)] = predictedDataFrame
+      .select("prediction_col", "label_index")
+      .rdd
+      .map(row => (row.getDouble(0), row.getDouble(1)))
+
+    val metrics = new MulticlassMetrics(predictAndLabelRDD)
+
+    // Overall accuracy first, then precision and recall for every label index.
+    println(s"总准确率:${metrics.accuracy}")
+    metrics.labels.foreach { label =>
+      println(s"矢量值为:$label 的精确率: ${metrics.precision(label)} ")
+      println(s"矢量值为:$label 的召回率: ${metrics.recall(label)} ")
+    }
+  }
+
+
+  
// Store the fitted pipeline model at the given location (e.g. an HDFS path), overwriting any existing model
+  def saveModel(path:String) :Unit ={
+    require(path != null && path.trim.nonEmpty, "model save path must not be empty")
+    pipelineModel.write.overwrite().save(path)
+  }
+
+
+
 
 
 
diff --git a/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/BusiGenderTrain.scala b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/BusiGenderTrain.scala
new file mode 100644
index 0000000..8ff3c7b
--- /dev/null
+++ b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/BusiGenderTrain.scala
@@ -0,0 +1,113 @@
+package com.atguigu.userprofile.ml.train
+
+import java.util.Properties
+
+import com.atguigu.userprofile.common.utils.MyPropertiesUtil
+import com.atguigu.userprofile.ml.pipline.MyPipeLine
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+/**
+ * Trains the business gender-prediction model and stores it for later use.
+ *
+ *  3. select the complete feature set plus the label column with SQL
+ *  4. define the pipeline
+ *  5. fetch the data and feed it into the pipeline for training
+ *  6. run a trial prediction and evaluate it (low score: check features, algorithm, parameters)
+ *  7. if the model is good enough, save it (HDFS)
+ */
+object BusiGenderTrain {
+
+  /** Entry point; `args` is not used. */
+  def main(args: Array[String]): Unit = {
+    // 0. Create the Spark environment (local master, Hive support enabled)
+    val sparkConf: SparkConf = new SparkConf().setAppName("busi_gender_train_app").setMaster("local[*]")
+    val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
+
+    try {
+      // 3. Features per user: category1 ids ranked 1-3 by page-log row count (c1_rk1..c1_rk3)
+      //    plus summed during_time for category1 ids (2,3,6) -> male_dur and (11,15,8) -> female_dur
+      println("取完整的特征 + label标签 SQL")
+      val selectSQL =
+        """
+          |with user_c1
+          |  as (
+          |    select user_id,category1_id,during_time
+          |    from dwd_page_log pl join dim_sku_info si
+          |    on pl.page_item = si.id
+          |    where pl.page_id = 'good_detail' and pl.page_item_type = 'sku_id' and pl.dt = '2022-05-20' and si.dt = '2022-05-20'
+          |  ),
+          |  user_label
+          |  as (
+          |    select id,gender from dim_user_info where dt='9999-99-99' and gender<>''
+          |  )
+          |select user_id,c1_rk1,c1_rk2,c1_rk3,male_dur,female_dur,user_label.gender
+          |from
+          |  (
+          |    select user_id,sum(if(rk = 1,category1_id,0)) c1_rk1,sum(if(rk = 2,category1_id,0)) c1_rk2,sum(if(rk = 3,category1_id,0)) c1_rk3,
+          |      sum(if(category1_id in (2,3,6) ,during_time,0)) male_dur,sum(if(category1_id in (11,15,8) ,during_time,0)) female_dur
+          |    from
+          |      (select user_id ,category1_id,ct,during_time,
+          |        row_number() over( partition by user_id order by ct desc) rk
+          |      from
+          |        (
+          |          select user_id ,category1_id,count(*) ct,sum(during_time) during_time
+          |          from user_c1
+          |          group by user_id,category1_id
+          |          order by user_id,category1_id
+          |        ) user_c1_ct
+          |      order by user_id,category1_id) user_rk
+          |    group by user_id
+          |  ) user_feature join user_label on user_feature.user_id=user_label.id
+          |order by user_id
+          |""".stripMargin
+
+      sparkSession.sql("use gmall")
+      val dataFrame: DataFrame = sparkSession.sql(selectSQL)
+
+      // Split 80% / 20% into training and test data; no seed is given, so the split
+      // differs from run to run
+      println("数据拆分")
+      val Array(trainDF, testDF) = dataFrame.randomSplit(Array(0.8, 0.2))
+
+      // 4. Define the pipeline
+      println("定义流水线")
+      val myPipeLine: MyPipeLine = new MyPipeLine()
+        .setLabelColName("gender")
+        .setFeatureColNames(Array("c1_rk1", "c1_rk2", "c1_rk3", "male_dur", "female_dur"))
+        .setMaxCategories(20) // decides which features count as continuous and which as discrete
+        .setMaxDepth(6)
+        .setMinInfoGain(0.03)
+        .setMaxBins(32)
+        .setMinInstancesPerNode(3)
+        .init()
+
+      // 5. Feed the training split into the pipeline
+      println("进行训练")
+      myPipeLine.train(trainDF)
+
+      // Print the decision tree and the feature weights
+      myPipeLine.printDecisionTree()
+      myPipeLine.printFeatureWeights()
+
+      // 6. Predict on the test split, show the predictions and print the evaluation report
+      println("进行预测")
+      val predictedDataFrame = myPipeLine.predict(testDF)
+      predictedDataFrame.show(100, false)
+      myPipeLine.printEvaluateReport(predictedDataFrame)
+
+      // 7. Save the model; fail with a clear message when the path is not configured
+      val properties: Properties = MyPropertiesUtil.load("config.properties")
+      val saveModelPath: String = Option(properties.getProperty("save-model.path"))
+        .filter(_.trim.nonEmpty)
+        .getOrElse(
+          throw new IllegalStateException("save-model.path is missing from config.properties")
+        )
+      myPipeLine.saveModel(saveModelPath)
+    } finally {
+      // Always release the Spark resources, even when training or saving fails
+      sparkSession.stop()
+    }
+  }
+
+}
diff --git a/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/StudGenderTrain.scala 
b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/StudGenderTrain.scala
index 17301fd..58511e5 100644
--- a/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/StudGenderTrain.scala
+++ b/Big_data_example/user_profile/machine-learning/src/main/scala/com/atguigu/userprofile/ml/train/StudGenderTrain.scala
@@ -2,7 +2,7 @@ package com.atguigu.userprofile.ml.train
 
 import com.atguigu.userprofile.ml.pipline.MyPipeLine
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 
 object StudGenderTrain {
 
@@ -49,6 +49,10 @@ object StudGenderTrain {
     val myPineLine = new MyPipeLine().setLabelColName("gender")
       .setFeatureColNames(Array("hair","height","skirt","age"))
       .setMaxCategories(5)
+      .setMaxDepth(6)
+      .setMinInfoGain(0.1)
+      .setMaxBins(32)
+      .setMinInstancesPerNode(4)
       .init()
     //4 杩涜璁粌
     println("杩涜璁粌...")
@@ -65,6 +69,18 @@ object StudGenderTrain {
 
     //false琛ㄧず涓嶅垏鍓
     predictedDataFrame.show(100,false)
+    //7 鎶婄煝閲忛娴嬬粨鏋滆浆鎹负鍘熷鍊
+    println("杩涜杞崲...")
+    val convertedDataFrame: DataFrame = myPineLine.convertOrigin(predictedDataFrame)
+    convertedDataFrame.show(100,false)
+
+    //8 鎵撳嵃璇勪及鎶ュ憡 // 鎬诲噯纭巼 // 鍚勪釜閫夐」鐨 鍙洖鐜 鍜岀簿纭巼
+    myPineLine.printEvaluateReport(convertedDataFrame)
+
+
+
+
+
   }
 
 }