A Gotcha in Scala Pattern Matching


I recently started doing some work in Scala.

Scala and Java are close kin, but the last time I wrote a Java program was nearly ten years ago. I normally work in dynamically typed languages, so switching to Scala suddenly took some getting used to. The small problem described here cost me quite a bit of time.


The scenario: I have two datasets, left and right, and need a full outer join (fullOuterJoin) on them. For each key in the result, prefer the value from left, and fall back to the value from right when left has none. Roughly:

scala> val left = sc.makeRDD(Seq(("1","2","LEFT12"),("2","3","LEFT23")))
left: org.apache.spark.rdd.RDD[(String, String, String)] = ParallelCollectionRDD[0] at makeRDD at <console>:24

scala> val right = sc.makeRDD(Seq(("1","2","RIGHT12"),("3","4","RIGHT34")))
right: org.apache.spark.rdd.RDD[(String, String, String)] = ParallelCollectionRDD[1] at makeRDD at <console>:24

scala> val leftmap = left.map(x => (x._1, x._2) -> x._3)
leftmap: org.apache.spark.rdd.RDD[((String, String), String)] = MapPartitionsRDD[2] at map at <console>:26

scala> val rightmap = right.map(x => (x._1, x._2) -> x._3)
rightmap: org.apache.spark.rdd.RDD[((String, String), String)] = MapPartitionsRDD[3] at map at <console>:26

Next comes the fullOuterJoin. According to the docs, fullOuterJoin yields, for each key, a tuple of two Options. Thanks to Scala's pattern matching, a pile of if/else can be avoided, and the following code looks lovely:

scala> val total = leftmap.fullOuterJoin(rightmap).map(kv =>
| kv._2 match {
| case (Some(s), None) => s
| case (None, Some(s)) => s
| case (Some(s), Some(t)) => s
| case _ => None
| })
total: org.apache.spark.rdd.RDD[java.io.Serializable] = MapPartitionsRDD[11] at map at <console>:32

scala> total.collect()
res3: Array[java.io.Serializable] = Array(RIGHT34, LEFT12, LEFT23)

But what on earth is this java.io.Serializable return type? It doesn't match the interface I need for persisting the data, so the rest of the pipeline is blocked.

After digging through Stack Overflow threads and assorted search results for java.io.Serializable, I finally figured it out: the type of a Scala match expression is the least common supertype (the least upper bound) of the types of all its case clauses. The culprit is the clause `case _ => None`: the least upper bound of None's type and the type of s (String) is what leads Scala to java.io.Serializable.
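This behavior has nothing to do with Spark and can be reproduced in plain Scala. A minimal sketch, where the `pick` helper is hypothetical and stands in for handling one joined value:

```scala
object LubDemo {
  // pick mimics handling one fullOuterJoin value: a pair of Options.
  // Two branches return a String, but the catch-all returns None
  // (of type None.type); the least upper bound of String and
  // None.type is java.io.Serializable, so that is what gets inferred.
  def pick(pair: (Option[String], Option[String])) = pair match {
    case (Some(s), _)    => s
    case (None, Some(t)) => t
    case _               => None
  }

  def main(args: Array[String]): Unit = {
    // Compiles: the inferred result type is java.io.Serializable...
    val ok: java.io.Serializable = pick((Some("LEFT12"), None))
    // ...but `val bad: String = pick(...)` would be a type error.
    println(ok)
  }
}
```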

To verify, try a match like this:

scala> val total = leftmap.fullOuterJoin(rightmap).map(kv =>
| kv._2 match {
| case (Some(s), None) => 1
| case (None, Some(s)) => 2
| case (Some(s), Some(t)) => 1
| case _ => ""
| })

The inferred type is Any, the common supertype of Int and String:

total: org.apache.spark.rdd.RDD[Any] = MapPartitionsRDD[19] at map at <console>:32

The fix is simple: since I know I want a String, have the last case clause return an empty string:

scala> val total = leftmap.fullOuterJoin(rightmap).map(kv =>
| kv._2 match {
| case (Some(s), None) => s
| case (None, Some(s)) => s
| case (Some(s), Some(t)) => s
| case _ => ""
| })
total: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[15] at map at <console>:32

scala> total.collect()
res4: Array[String] = Array(RIGHT34, LEFT12, LEFT23)
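As an aside, once the intent is "prefer left, fall back to right", the same logic can be written without an explicit match, using Option's combinators. A sketch (the `pick` helper is hypothetical; the `""` fallback mirrors the fix above, although in a full outer join at least one side is always defined):

```scala
object OrElseDemo {
  // Handles one joined value: prefer the left Option, fall back to right.
  // Every branch is a String, so the inferred type stays String.
  def pick(l: Option[String], r: Option[String]): String =
    l.orElse(r).getOrElse("")

  def main(args: Array[String]): Unit = {
    println(pick(Some("LEFT12"), Some("RIGHT12"))) // LEFT12
    println(pick(None, Some("RIGHT34")))           // RIGHT34
  }
}
```

On the RDD this would look something like `leftmap.fullOuterJoin(rightmap).mapValues { case (l, r) => l.orElse(r).getOrElse("") }`, which keeps the element type at String.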

I'm still a newbie at Scala/Spark, so if anything here is inaccurate, corrections from more experienced readers are very welcome.



Please credit the source when reposting: http://blog.guoyb.com/2017/12/09/scala-match/
