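Notes on a Spark MLlib StackOverflowError pitfall
When I worked on ALS before, I mostly used my company's internal tools. Today I tried Spark MLlib for the first time, testing the waters with a small dataset of just a few MB.
I ran it and... what?! A StackOverflowError, of all things.
Here's the source code: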
```python
from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS

if __name__ == "__main__":
    # Alternatively, with the SparkSession API:
    # spark = SparkSession.builder.appName("PythonWordCount").getOrCreate()
    # sc = spark.sparkContext
    sc = SparkContext(appName="PythonWordCount")

    # Each input line is "user product rating", separated by single spaces
    data = sc.textFile("CollaborativeFiltering.txt", 20)
    ratings = data.map(lambda line: [float(x) for x in line.split(" ")]).persist()

    # Train with 10 latent factors for 30 iterations
    rank = 10
    n = 30
    model = ALS.train(ratings, rank, n)

    # Predict every (user, product) pair from the training data and compute the MSE
    testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
    predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
    ratesAndPreds = ratings.map(lambda r: ((int(r[0]), int(r[1])), r[2])).join(predictions).persist()
    MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1]) ** 2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
    print("Mean Squared Error = " + str(MSE))
    ratesAndPreds.unpersist()
```
The error message:
```
2017-11-24 17:15:23 [INFO] ShuffleMapStage 66 (flatMap at ALS.scala:1272) failed in Unknown s due to Job aborted due to stage failure: Task serialization failed: java.lang.StackOverflowError
java.lang.StackOverflowError
    at java.io.ObjectOutputStream$BlockDataOutputStream.write(ObjectOutputStream.java:1841)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1534)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
    at scala.collection.immutable.$colon$colon.writeObject(List.scala:379)
    at sun.reflect.GeneratedMethodAccessor15.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1028)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
```
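Cue the tears. Lots and lots of tears.
Eventually I started to suspect that the lineage had grown too long, so I googled around and asked some experts for help.
Sure enough, that was the cause. In an iterative computation, every iteration builds new RDDs on top of the previous ones, so the lineage (the chain of dependencies Spark keeps so it can recompute lost data) keeps getting longer. When a task is serialized, Java serialization walks that chain recursively (the repeating ObjectOutputStream frames in the trace), so the stack space it needs grows with the lineage until the stack finally overflows.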
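To make the mechanism above concrete, here is a toy model in plain Python; it is not Spark code. Each node is a plain dict (a hypothetical stand-in for one derived RDD) that only points to its parent, and serialize() is a naive recursive serializer playing the role of ObjectOutputStream. Python's recursion limit stands in for the JVM thread stack size. The depths are arbitrary, and a real ALS iteration adds several RDDs per iteration rather than one.

```python
def build_lineage(iterations):
    """Toy lineage: each step derives a new node that references the previous one."""
    node = {"op": "source", "parent": None}
    for i in range(iterations):
        node = {"op": f"iteration-{i}", "parent": node}
    return node


def lineage_depth(node):
    """Count the ancestors behind a node (iteratively, so this never overflows)."""
    depth = 0
    while node["parent"] is not None:
        node = node["parent"]
        depth += 1
    return depth


def serialize(node):
    """Naive recursive serializer: writes the parent chain first, then this node."""
    if node["parent"] is None:
        return node["op"]
    return serialize(node["parent"]) + " -> " + node["op"]


def can_serialize(node):
    """True if serialization finishes, False if it runs out of (Python) stack."""
    try:
        serialize(node)
        return True
    except RecursionError:
        return False


if __name__ == "__main__":
    short = build_lineage(30)
    deep = build_lineage(100_000)
    print(lineage_depth(short), can_serialize(short))  # 30 True
    print(lineage_depth(deep), can_serialize(deep))    # 100000 False
```

The point of the model: the data carried by each node is tiny, yet serialization still fails once the chain is deep enough, because the recursion depth tracks the lineage length rather than the data size.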
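The fix for this kind of problem:
Add sc.setCheckpointDir(path) to the code to explicitly set a checkpoint directory, and the problem goes away. This works because MLlib's ALS checkpoints its intermediate RDDs periodically (every 10 iterations by default), but only once a checkpoint directory has been set; each checkpoint saves the data to disk and truncates the lineage, so the dependency chain never grows deep enough to overflow the stack. The catch is that checkpointing writes to disk, so as the data grows, disk I/O becomes the bottleneck. I haven't solved that part yet; if any of you have a better solution, feel free to get in touch!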
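MLlib's documentation for ALS.setCheckpointInterval says checkpoints are taken every 10 iterations by default and that the setting is ignored when no checkpoint directory is set on the SparkContext. The sketch below is a simplified model of that schedule, not Spark's actual code (the exact iteration indices Spark picks may be offset); the function names are hypothetical. It shows the trade-off behind the I/O concern: a smaller interval keeps the lineage shorter but means more checkpoint writes to disk.

```python
def checkpoint_iterations(iterations, interval=10, checkpoint_dir_set=True):
    """Iterations (1-based) after which this simplified model writes a checkpoint.

    Assumed rule: no checkpoint directory, or interval == -1, means no checkpoints;
    otherwise every `interval`-th iteration is checkpointed.
    """
    if not checkpoint_dir_set or interval == -1:
        return []
    if interval < 1:
        raise ValueError("interval must be >= 1, or -1 to disable checkpointing")
    return [i for i in range(1, iterations + 1) if i % interval == 0]


def longest_lineage(iterations, interval=10, checkpoint_dir_set=True):
    """Most iterations that pile up on the lineage before a checkpoint truncates it."""
    boundaries = checkpoint_iterations(iterations, interval, checkpoint_dir_set) + [iterations]
    longest, last = 0, 0
    for b in boundaries:
        longest = max(longest, b - last)
        last = b
    return longest


if __name__ == "__main__":
    n = 30  # same iteration count as the ALS.train call in this post
    for dir_set, interval in [(False, 10), (True, 10), (True, 2)]:
        writes = len(checkpoint_iterations(n, interval, dir_set))
        print(f"dir set={dir_set}, interval={interval}: {writes} checkpoint writes, "
              f"longest lineage {longest_lineage(n, interval, dir_set)} iterations")
```

With 30 iterations, the model gives 0 writes and a 30-iteration lineage without a checkpoint directory, 3 writes and at most 10 iterations with the default interval, and 15 writes and at most 2 iterations with an interval of 2.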
The modified code:
```python
from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS

if __name__ == "__main__":
    # Alternatively, with the SparkSession API:
    # spark = SparkSession.builder.appName("PythonWordCount").getOrCreate()
    # sc = spark.sparkContext
    sc = SparkContext(appName="PythonWordCount")
    # ALS only checkpoints (and so truncates its lineage) when a checkpoint directory is set;
    # on a cluster this should be a shared path such as HDFS rather than a local directory
    sc.setCheckpointDir("checkpoint")

    data = sc.textFile("CollaborativeFiltering.txt", 20)
    ratings = data.map(lambda line: [float(x) for x in line.split(" ")]).persist()

    rank = 10
    n = 30
    # pyspark.mllib's ALS.train() has no checkpoint-interval argument (the default of 10 applies).
    # The DataFrame-based pyspark.ml.recommendation.ALS exposes it as a parameter, e.g.
    # ALS(rank=10, maxIter=30, checkpointInterval=2)
    model = ALS.train(ratings, rank, n)

    testdata = ratings.map(lambda r: (int(r[0]), int(r[1])))
    predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
    ratesAndPreds = ratings.map(lambda r: ((int(r[0]), int(r[1])), r[2])).join(predictions).persist()
    MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1]) ** 2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
    print("Mean Squared Error = " + str(MSE))
    ratesAndPreds.unpersist()
```