Spark使用JAVA编写自定义函数修改DataFrame

本文的代码涉及几个知识点,都是比较有用:

1、Spark用JAVA编写代码的方式;

2、Spark读取MySQL数据表,并且使用的是自定义SQL的方式,默认会读取整个表的;

3、Spark使用sql.functions的原有方法,给dataframe新增列、变更列;

4、Spark使用udf的自定义函数,给dataframe新增列、变更列;

/**
 * spark直接读取mysql
 */
private static Dataset<Row> queryMySQLData(SparkSession spark) {
    Properties properties = new Properties();
    properties.put("user", "root");
    properties.put("password", "12345678");
    properties.put("driver", "com.mysql.jdbc.Driver");
    // 可以写SQL语句查询数据结果
    return spark.read().jdbc(
            "jdbc:mysql://127.0.0.1:3306/test"
            , "(select id, name from tb_data) tsub",
            properties);
}

整个函数使用spark.read().jdbc读取mysql数据表,配置了mysql的user、passpord、driver,jdbcurl,以及可以通过sql语句执行数据查询,sql语句这里在spark源文档是table name,如果只设置table name,则会读取整个表,可以使用(select id, name from tb_data) tsub的方式读取SQL结果,注意的是这里必须给SQL语句设定一个标的别名。

以下是几种给dataframe添加新列、修改原有列的方法

方法1:使用functions中的函数,有一些局限性

// 方法1:使用functions中的函数,有一些局限性
inputData.withColumn("name_length_method1", functions.length(inputData.col("name")));

使用的是sql.functions里面的方法,里面支持了大部分的size、length等等方法,不过还是不够灵活,因为不支持就是不支持;

方法2:自定义注册udf,可以用JAVA代码写处理

可以先用spark.udf().register注册方法,然后使用functions.callUDF进行调用,其中自定义方法需要实现UDF1~UDF20的接口,分别代表传入不同的入参列:

// 方法2:自定义注册udf,可以用JAVA代码写处理
spark.udf().register(
        "getLength",
        new UDF1<String, Integer>() {
            @Override
            public Integer call(String s) throws Exception {
                return s.length();
            }
        },
        DataTypes.IntegerType);

inputData = inputData.withColumn(
        "name_length_method2",
        functions.callUDF("getLength",
                inputData.col("name"))
);

// 方法2.1:可以写UDF2~UDF20,就是把输入字段变成多个
spark.udf().register(
        "getLength2",
        new UDF2<Long, String, Long>() {

            @Override
            public Long call(Long aLong, String s) throws Exception {
                return aLong + s.length();
            }
        },
        DataTypes.LongType);

inputData = inputData.withColumn(
        "name_length_method3",
        functions.callUDF(
                "getLength2",
                inputData.col("id"),
                inputData.col("name"))
);

inputData.show(20, false);

代码地址见:github地址

tb_data的mysql表数据读取后的原始dataframe的schema:

root
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)

数据:

+---+-----------+
|id |name       |
+---+-----------+
|1  |name1      |
|2  |name22     |
|3  |name333    |
|4  |name4444   |
|5  |name55555  |
|6  |name666666 |
|7  |name7777777|
+---+-----------+

最终计算之后的数据输出:

+---+-----------+-------------------+-------------------+-------------------+
|id |name       |name_length_method1|name_length_method2|name_length_method3|
+---+-----------+-------------------+-------------------+-------------------+
|1  |name1      |5                  |5                  |6                  |
|2  |name22     |6                  |6                  |8                  |
|3  |name333    |7                  |7                  |10                 |
|4  |name4444   |8                  |8                  |12                 |
|5  |name55555  |9                  |9                  |14                 |
|6  |name666666 |10                 |10                 |16                 |
|7  |name7777777|11                 |11                 |18                 |
+---+-----------+-------------------+-------------------+-------------------+

 

相关推荐

Leave a Comment