Spark使用Java开发遇到的那些类型错误

Spark使用Java开发其实比较方便的,JAVA8的lambda表达式使得编写体验并不比Scala差很多,但是因为Spark本身使用Scala实现,导致使用Java开发的时候,也遇到不少的类型匹配问题。

本文列举出自己在工作开发中遇到的一些问题,供大家参考:

WrappedArray和Vector

报错信息为:Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to org.apache.spark.ml.linalg.Vector

当使用DataFrame打印Schema的时候,是这样的输出:

 |-- tag_weights: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- word_sims: array (nullable = true)
 |    |-- element: double (containsNull = true)

 

这时候如果Java用Vector接收,就会报这个错误,JAVA代码为:

spark.udf().register(
                "computeWeightSim",
                new UDF2<Vector, Vector, Double>() {
                    @Override
                    public Double call(Vector tag_weights, Vector word_sims) throws Exception {

解决办法是使用WrappedArray<Long>来接收,这是个scala的类型,可以用Iterator做遍历:

scala.collection.Iterator<Long> it1 = view_qipuids.iterator();
scala.collection.Iterator<Long> it2 = view_cnts.iterator();

Map<Long, Long> viewMap = new HashMap<>();
while (it1.hasNext() && it2.hasNext()) {
    viewMap.put(it1.next(), it2.next());
}

或者可以zip两个iterator进行计算:

new UDF2<WrappedArray<Double>, WrappedArray<Double>, Double>() {
    /**
     * 计算加权权重
     * @param tag_weights 加权
     * @param word_sims 计算结果目标
     * @return 加权权重
     * @throws Exception
     */
    @Override
    public Double call(WrappedArray<Double> tag_weights, WrappedArray<Double> word_sims) throws Exception {
        scala.collection.Iterator<Double> tag_weightsIter = tag_weights.iterator();
        scala.collection.Iterator<Double> word_simsIter = word_sims.iterator();

        scala.collection.Iterator<Tuple2<Double, Double>> zipIterator = tag_weightsIter.zip(word_simsIter);

        double totalWeight = 0.0;
        double fenziWeight = 0.0;
        while (zipIterator.hasNext()) {
            Tuple2<Double, Double> iterTuple = zipIterator.next();
            totalWeight += iterTuple._1;
            fenziWeight += iterTuple._1 * iterTuple._2;
        }

        if (totalWeight == 0.0) {
            return 0.0;
        } else {
            return fenziWeight / totalWeight;
        }
    }
}

 

详细内容见scala的文档:https://docs.scala-lang.org/overviews/collections/iterators.html

 

Spark使用JAVA编写自定义函数修改DataFrame

本文的代码涉及几个知识点,都是比较有用:

1、Spark用JAVA编写代码的方式;

2、Spark读取MySQL数据表,并且使用的是自定义SQL的方式,默认会读取整个表的;

3、Spark使用sql.functions的原有方法,给dataframe新增列、变更列;

4、Spark使用udf的自定义函数,给dataframe新增列、变更列;

/**
 * spark直接读取mysql
 */
private static Dataset<Row> queryMySQLData(SparkSession spark) {
    Properties properties = new Properties();
    properties.put("user", "root");
    properties.put("password", "12345678");
    properties.put("driver", "com.mysql.jdbc.Driver");
    // 可以写SQL语句查询数据结果
    return spark.read().jdbc(
            "jdbc:mysql://127.0.0.1:3306/test"
            , "(select id, name from tb_data) tsub",
            properties);
}

整个函数使用spark.read().jdbc读取mysql数据表,配置了mysql的user、passpord、driver,jdbcurl,以及可以通过sql语句执行数据查询,sql语句这里在spark源文档是table name,如果只设置table name,则会读取整个表,可以使用(select id, name from tb_data) tsub的方式读取SQL结果,注意的是这里必须给SQL语句设定一个标的别名。

以下是几种给dataframe添加新列、修改原有列的方法

方法1:使用functions中的函数,有一些局限性

// 方法1:使用functions中的函数,有一些局限性
inputData.withColumn("name_length_method1", functions.length(inputData.col("name")));

使用的是sql.functions里面的方法,里面支持了大部分的size、length等等方法,不过还是不够灵活,因为不支持就是不支持;

方法2:自定义注册udf,可以用JAVA代码写处理

可以先用spark.udf().register注册方法,然后使用functions.callUDF进行调用,其中自定义方法需要实现UDF1~UDF20的接口,分别代表传入不同的入参列:

// 方法2:自定义注册udf,可以用JAVA代码写处理
spark.udf().register(
        "getLength",
        new UDF1<String, Integer>() {
            @Override
            public Integer call(String s) throws Exception {
                return s.length();
            }
        },
        DataTypes.IntegerType);

inputData = inputData.withColumn(
        "name_length_method2",
        functions.callUDF("getLength",
                inputData.col("name"))
);

// 方法2.1:可以写UDF2~UDF20,就是把输入字段变成多个
spark.udf().register(
        "getLength2",
        new UDF2<Long, String, Long>() {

            @Override
            public Long call(Long aLong, String s) throws Exception {
                return aLong + s.length();
            }
        },
        DataTypes.LongType);

inputData = inputData.withColumn(
        "name_length_method3",
        functions.callUDF(
                "getLength2",
                inputData.col("id"),
                inputData.col("name"))
);

inputData.show(20, false);

代码地址见:github地址

tb_data的mysql表数据读取后的原始dataframe的schema:

root
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)

数据:

+---+-----------+
|id |name       |
+---+-----------+
|1  |name1      |
|2  |name22     |
|3  |name333    |
|4  |name4444   |
|5  |name55555  |
|6  |name666666 |
|7  |name7777777|
+---+-----------+

最终计算之后的数据输出:

+---+-----------+-------------------+-------------------+-------------------+
|id |name       |name_length_method1|name_length_method2|name_length_method3|
+---+-----------+-------------------+-------------------+-------------------+
|1  |name1      |5                  |5                  |6                  |
|2  |name22     |6                  |6                  |8                  |
|3  |name333    |7                  |7                  |10                 |
|4  |name4444   |8                  |8                  |12                 |
|5  |name55555  |9                  |9                  |14                 |
|6  |name666666 |10                 |10                 |16                 |
|7  |name7777777|11                 |11                 |18                 |
+---+-----------+-------------------+-------------------+-------------------+

 

PyCharm开发PySpark程序的配置和实例

对于PyCharm,需要作如下设置:
1、安装pyspark,它会自动安装py4j
2、在edit configuration中,add content root,选择spark下载包的python/pyspark/lib下的pyspark.zip和py4j.zip两个包;

代码实例:

from pyspark.sql import Row
from pyspark.sql import SparkSession

logFile = "file:///Users/peishuaishuai/tmp/sparktest.txt"  # Should be some file on your system
spark = SparkSession.builder.appName("SimpleApp").getOrCreate()

input = spark.read.text(logFile).rdd.map(
    lambda x: str(x[0]).split("\t")
).filter(
    lambda x: len(x) == 2
).map(
    lambda x: Row(name=x[0], grade=int(x[1]))
)

schemaData = spark.createDataFrame(input)
schemaData.createOrReplaceTempView("tb")

print(schemaData.count())
schemaData.printSchema()

datas = spark.sql("select name,sum(grade) from tb group by name").rdd.map(
    lambda x: "\t".join([x[0], str(x[1])])
)

datas.repartition(3).saveAsTextFile("file:///Users/peishuaishuai/tmp/sparktest_output")

spark.stop()

 

输入数据为:

name1	11
name2	12
name3	13
name4	14
name5	15
name1	16
name2	17
name3	18
name4	19
name5	20
name11	21
name12	22
name1	23
name2	24
name3	25
name4	26
name5	27
name18	28
name19	29
name20	30
name21	31
name1	32
name2	33
name3	34
name4	35
name5	36
name27	37
name28	38
name29	39
name1	40
name2	41
name3	42
name4	43

输出 print结果为:

33
root
 |-- grade: long (nullable = true)
 |-- name: string (nullable = true)

文件中内容为:

name3	132
name19	29
name2	127
name12	22
name11	21
name20	30
name28	38
name27	37
name5	98
name29	39
name21	31
name4	137
name1	122
name18	28

pyspark开发起来,有点问题就是当级联过多的时候,类型可能丢失,导致代码没有提示,这点很不爽。

其实对比了python、scala、java,我觉得编写大型的spark代码,用Java是最靠谱的,因为它强类型,代码提示很爽很直观。

 

Spark数据倾斜解决方法

1、避免shuffle,改reduce join为map join,适用于JOIN的时候有一个表是小表的情况,直接使用collect()获取小表的所有数据,然后brodcast,对大表进行MAP,MAP时直接提取broadcast的小表数据实现JOIN;

2、随机数的方案,对于聚合类操作,可以分步骤进行聚合,第一步,在原来的KEY后面加上随机数(比如1~10),然后进行聚合(比如SUM操作);第二步去掉KEY后面的随机数;第三部再次聚合(对应第一步的SUM),只适用于聚合类场景;

3、HIVE预处理的方案,如果已经有数据倾斜,则用HIVE预处理,然后将结果加载到SPARK中进行使用,适用于SPARK会频繁使用但是HIVE只会预计算一次的场景,用于即席查询比较多;

4、修改或者提升shuffle的并行度,使用repatition进行,比如原来每个节点处理10个KEY的数据,现在处理3个KEY的数据,虽然某些KEY仍然是热点,但是会缓解不少;

5、过滤掉发生倾斜的KEY,场景较少,可以用采样、预计算的方式计算出KEY的数量分布,然后过滤掉最多的KEY的数据即可;

6、分治法+空间浪费法,将A表中热点KEY的数据单独提取出来,对KEY加上随机前缀;然后将B表对应热点KEY的数据提取出来,重复加上所有的随机数KEY,然后这俩RDD关联,得到热点的结果RDD;对于A/B剩下的数据,按普通的进行JOIN,得到普通结果的RDD;然后将热点RDD和普通RDD进行UNION得到最终结果;

7、完全空间浪费法,对A表所有数据的KEY加随机前缀,对B表所有KEY做重复加上所有的随机前缀,然后做关联得到结果;

8、多种方法配合使用;

本文总结自:https://www.iteblog.com/archives/1671.html

本文地址:http://crazyant.net/2231.html