import cn.doitedu.commons.util.SparkUtil
import org.apache.spark.ml.linalg
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
object KnnDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkUtil.getSparkSession("KNN")
    import spark.implicits._
    // Example: a plain Scala function that computes the cosine similarity of two vectors
    val cossim = (v1: linalg.Vector, v2: linalg.Vector) => {
      val v1arr = v1.toArray
      val v2arr = v2.toArray
      // squared L2 norm of v1
      val m1: Double = v1arr.map(Math.pow(_, 2)).sum
      // squared L2 norm of v2
      val m2: Double = v2arr.map(Math.pow(_, 2)).sum
      // dot product of v1 and v2
      val innerProduct: Double = v1arr.zip(v2arr).map(tp => tp._1 * tp._2).sum
      // cosine similarity = dot product / (||v1|| * ||v2||)
      innerProduct / Math.pow(m1 * m2, 0.5)
    }
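    // Quick sanity check of the function above (illustrative values only):
    //   cossim(Vectors.dense(1.0, 2.0), Vectors.dense(2.0, 4.0))  // 1.0 (same direction)
    //   cossim(Vectors.dense(1.0, 0.0), Vectors.dense(0.0, 1.0))  // 0.0 (orthogonal)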
    // Next, use the Scala function inside Spark SQL
    /**
     * Method 1: register the Scala function as a SQL function
     */
    // register it
    spark.udf.register("cossim", cossim)
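    // Method 2 (sketch, not used below): wrap the same Scala function with
    // org.apache.spark.sql.functions.udf and call it from the DataFrame DSL.
    // The names cossimUdf and joinedDf are illustrative:
    //   import org.apache.spark.sql.functions.udf
    //   val cossimUdf = udf(cossim)
    //   joinedDf.withColumn("sim", cossimUdf($"vec", $"b_vec"))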
    // read the sample (training) data
    val ds1: Dataset[String] = spark.read.textFile("userprofile/data/demo/a.txt")
    // read the test data
    val ds2: Dataset[String] = spark.read.textFile("userprofile/data/demo/b.txt")
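    // Expected file layout, inferred from the parsing below:
    //   a.txt: label,feature1,feature2,...   (comma-separated, label first)
    //   b.txt: feature1,feature2,...         (comma-separated features only)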
    // parse the sample data
    val tp1: DataFrame = ds1.map(line => {
      val arr: Array[String] = line.split(",")
      val label: String = arr(0)
      // extract the feature values
      val features: Array[Double] = arr.tail.map(_.toDouble)
      // wrap the feature values in a vector
      val vector: Vector = Vectors.dense(features)
      (label, vector)
    }).toDF("label", "vec")
    // parse the data to predict (9999 is a placeholder label for the unlabeled test rows)
    val tp2: DataFrame = ds2.map(line2 => {
      val arr2: Array[String] = line2.split(",")
      val features: Array[Double] = arr2.map(_.toDouble)
      val v2: Vector = Vectors.dense(features)
      (9999, v2)
    }).toDF("b_label", "b_vec")
    // join the two datasets (with the placeholder b_label this effectively pairs every sample with every test row)
    // Spark also has a dedicated cartesian-product operator, crossJoin, which could replace
    // this join here and needs no join condition.
    val joined: DataFrame = tp1.join(tp2, 'label < 'b_label)
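    // Sketch of the crossJoin alternative mentioned above (pure cartesian product,
    // no join condition needed):
    //   val joined: DataFrame = tp1.crossJoin(tp2)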
    joined.createTempView("joined")

    // for every test vector (b_vec), keep the 5 most similar samples
    val res1: DataFrame = spark.sql(
      """
        |select
        |  label,
        |  b_vec,
        |  sim
        |from
        |(
        |  select
        |    label,
        |    b_vec,
        |    cossim(vec,b_vec) as sim,
        |    row_number() over(partition by b_vec order by cossim(vec,b_vec) desc) as rn
        |  from
        |    joined
        |)t
        |where rn <= 5
      """.stripMargin)
    res1.createTempView("res")

    // KNN vote: for every test vector, pick the label that occurs most often
    // among its 5 nearest samples
    spark.sql(
      """
        |select
        |  label,
        |  b_vec
        |from
        |(
        |  select
        |    label,
        |    b_vec,
        |    row_number() over(partition by b_vec order by count(1) desc) as rn
        |  from res
        |  group by label,b_vec
        |)t
        |where rn = 1
      """.stripMargin)
      .show()
    spark.close()
  }
}