Pydeequ - Serialize Lambda function
0xbadidea opened this issue · 2 comments
I'm not reporting a bug, just looking for a workaround and I'm hoping someone can help!
I'm trying to call deequ's rowLevelResultsAsDataFrame function from pydeequ. Everything works fine, but as soon as I add a lambda assertion to the VerificationSuite checks, I start getting serialization errors.
pyspark==3.5
pydeequ==1.4.0
jar: com.amazon.deequ:deequ:2.0.7-spark-3.5
24/09/01 17:01:54 WARN SerializationDebugger: Exception in serialization debugger
py4j.Py4JException: An exception was raised by the Python Proxy. Return Message: Traceback (most recent call last):
File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/clientserver.py", line 617, in _call_proxy
return_value = getattr(self.pool[obj_id], method)(*params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ScalaFunction1' object has no attribute 'hashCode'
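From what I can tell, pydeequ hands the lambda to the JVM as a py4j callback object, roughly like the sketch below (simplified from pydeequ's scala_utils.ScalaFunction1; not the exact source). The JVM side only holds a dynamic proxy backed by py4j's PythonProxyHandler, so every method call on it, including the hashCode that Spark's SerializationDebugger issues while walking the object graph, is forwarded to the Python object, which defines no hashCode, hence the AttributeError above.

# Simplified sketch of pydeequ's callback wrapper (see pydeequ/scala_utils.py).
class ScalaFunction1:
    """Lets a Python callable stand in for a scala.Function1 on the JVM side."""

    def __init__(self, gateway, lambda_function):
        self.gateway = gateway
        self.lambda_function = lambda_function

    def apply(self, arg):
        # py4j forwards Function1.apply(...) calls from the JVM to this method.
        return self.lambda_function(arg)

    class Java:
        # Tells py4j which JVM interfaces this Python object implements.
        implements = ["scala.Function1"]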
A minimal snippet is provided below. The code runs fine if we remove the hasMin("b", lambda x: x == 0) check.
from pyspark.sql import SparkSession
from pyspark.sql import Row, DataFrame
from pydeequ.verification import VerificationResult, VerificationSuite
from pydeequ.checks import CheckLevel, Check
import pydeequ

spark = (
    SparkSession.builder.config("spark.jars.packages", pydeequ.deequ_maven_coord)
    .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
    .getOrCreate()
)

df = spark.sparkContext.parallelize(
    [Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]
).toDF()

check = Check(spark, CheckLevel.Error, "Example Check")

verification_result = (
    VerificationSuite(spark)
    .onData(df)
    .addCheck(
        check.hasMin("b", lambda x: x == 0)
        .isComplete("c")
        .isUnique("a")
        .isContainedIn("a", ["foo", "bar", "baz"])
        .isNonNegative("b")
    )
    .run()
)

checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, verification_result)

# pydeequ does not expose row-level results, so call the deequ API on the JVM directly.
java_result_df = spark._jvm.com.amazon.deequ.VerificationResult.rowLevelResultsAsDataFrame(
    spark._jsparkSession, verification_result.verificationRun, df._jdf
)
result_df = DataFrame(java_result_df, spark)
result_df.show()
Can someone please provide a workaround?
Here's the full stack trace for the error:
py4j.Py4JException: An exception was raised by the Python Proxy. Return Message: Traceback (most recent call last):
File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/clientserver.py", line 617, in _call_proxy
return_value = getattr(self.pool[obj_id], method)(*params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ScalaFunction1' object has no attribute 'hashCode'
at py4j.Protocol.getReturnValue(Protocol.java:476)
at py4j.reflection.PythonProxyHandler.invoke(PythonProxyHandler.java:108)
at jdk.proxy3/jdk.proxy3.$Proxy34.hashCode(Unknown Source)
at scala.collection.mutable.FlatHashTable.findElemImpl(FlatHashTable.scala:131)
at scala.collection.mutable.FlatHashTable.containsElem(FlatHashTable.scala:126)
at scala.collection.mutable.FlatHashTable.containsElem$(FlatHashTable.scala:125)
at scala.collection.mutable.HashSet.containsElem(HashSet.scala:41)
at scala.collection.mutable.HashSet.contains(HashSet.scala:58)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:87)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
at org.apache.spark.serializer.SerializationDebugger$.find(SerializationDebugger.scala:67)
at org.apache.spark.serializer.SerializationDebugger$.improveException(SerializationDebugger.scala:41)
at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:49)
at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:115)
at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:441)
at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:416)
at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:163)
at org.apache.spark.SparkContext.clean(SparkContext.scala:2674)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$1(RDD.scala:904)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:903)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:759)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191)
at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:364)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:498)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4344)
at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3326)
at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334)
at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332)
at org.apache.spark.sql.Dataset.head(Dataset.scala:3326)
at org.apache.spark.sql.Dataset.take(Dataset.scala:3549)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:569)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
at java.base/java.lang.Thread.run(Thread.java:840)
Traceback (most recent call last):
File "/workspaces/cubie/cubie/transformers/temp.py", line 39, in <module>
result_df.show()
File "/opt/conda/envs/dev/lib/python3.12/site-packages/pyspark/sql/dataframe.py", line 959, in show
print(self._jdf.showString(n, 20, vertical))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/java_gateway.py", line 1322, in __call__
return_value = get_return_value(
^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/dev/lib/python3.12/site-packages/pyspark/errors/exceptions/captured.py", line 179, in deco
return f(*a, **kw)
^^^^^^^^^^^
File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/protocol.py", line 326, in get_return_value
raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling o82.showString.
: org.apache.spark.SparkException: Task not serializable
at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:444)
at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:416)
at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:163)
at org.apache.spark.SparkContext.clean(SparkContext.scala:2674)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$1(RDD.scala:904)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:903)
at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:759)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191)
at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:364)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:498)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4344)
at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3326)
at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334)
at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332)
at org.apache.spark.sql.Dataset.head(Dataset.scala:3326)
at org.apache.spark.sql.Dataset.take(Dataset.scala:3549)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:569)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: java.io.NotSerializableException: py4j.reflection.PythonProxyHandler
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
at java.base/java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:350)
at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:46)
at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:115)
at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:441)
... 45 more
@0xbadidea I think the root of the problem is there, and I do not see how it can be fixed, because py4j.reflection.PythonProxyHandler is not serializable at all (there is a good explanation in the linked py4j discussion). I tried adding scala.Serializable to the list of interfaces implemented by scala_utils.ScalaFunction1, but it did not help, because the problem is in py4j itself.
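The change I tried looked roughly like this (sketch only). Declaring the Scala interface Serializable changes nothing, because what Spark actually tries to serialize is the JVM-side PythonProxyHandler behind the proxy, and that handler holds a live connection back to the Python process:

# Attempted (unsuccessful) fix in pydeequ/scala_utils.py - sketch only.
class ScalaFunction1(ScalaFunction):
    def apply(self, arg):
        return self.lambda_function(arg)

    class Java:
        # Adding scala.Serializable to the implemented interfaces does not
        # help: the proxy's InvocationHandler (py4j.reflection.PythonProxyHandler)
        # is what must be serialized, and it is not Serializable.
        implements = ["scala.Function1", "scala.Serializable"]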
Thanks @SemyonSinchenko! Appreciate your help. Guess there's no straightforward way to implement this method in pydeequ.
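For anyone who lands here later: since the snippet runs fine once the lambda check is removed, one partial workaround (a sketch, untested beyond that observation) is to keep lambda-based assertions in the suite used for checkResultsAsDataFrame, and build the row-level DataFrame from a second suite that avoids Python lambdas entirely. Row-level output for the lambda-based constraint is lost, but the rest comes through:

# Partial workaround sketch: row-level results from a lambda-free suite.
# Reuses spark, df, and the imports from the snippet above.
lambda_free_check = (
    Check(spark, CheckLevel.Error, "Row-level Check")
    .isComplete("c")
    .isUnique("a")
    .isContainedIn("a", ["foo", "bar", "baz"])
    .isNonNegative("b")  # no Python lambda anywhere in this suite
)
row_level_result = VerificationSuite(spark).onData(df).addCheck(lambda_free_check).run()
java_result_df = spark._jvm.com.amazon.deequ.VerificationResult.rowLevelResultsAsDataFrame(
    spark._jsparkSession, row_level_result.verificationRun, df._jdf
)
DataFrame(java_result_df, spark).show()  # no Python proxy left in the plan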