awslabs/python-deequ

Pydeequ - Serialize Lambda function

0xbadidea opened this issue · 2 comments

I'm not reporting a bug, just looking for a workaround and I'm hoping someone can help!

I'm trying to call deequ's rowLevelResultsAsDataFrame function from pydeequ. Things work fine but as soon as I add a lambda function to the Verification Suite checks, I start getting serialization errors.

pyspark==3.5
pydeequ==1.4.0
jar: com.amazon.deequ:deequ:2.0.7-spark-3.5

24/09/01 17:01:54 WARN SerializationDebugger: Exception in serialization debugger
py4j.Py4JException: An exception was raised by the Python Proxy. Return Message: Traceback (most recent call last):
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/clientserver.py", line 617, in _call_proxy
    return_value = getattr(self.pool[obj_id], method)(*params)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ScalaFunction1' object has no attribute 'hashCode'

A working snippet is provided below. The code runs fine if we remove the check hasMin("b", lambda x: x == 0).

from pyspark.sql import SparkSession
from pydeequ.verification import VerificationResult, VerificationSuite
from pydeequ.checks import CheckLevel, Check
from pyspark.sql import Row, DataFrame
import pydeequ

spark = (
    SparkSession.builder.config("spark.jars.packages", pydeequ.deequ_maven_coord)
    .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
    .getOrCreate()
)

df = spark.sparkContext.parallelize(
    [Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]
).toDF()

check = Check(spark, CheckLevel.Error, "Example Check")

verification_result = (
    VerificationSuite(spark)
    .onData(df)
    .addCheck(
        check.hasMin("b", lambda x: x == 0).isComplete("c")
        .isUnique("a")
        .isContainedIn("a", ["foo", "bar", "baz"])
        .isNonNegative("b")
    )
    .run()
)
checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, verification_result)
java_result_df = (
    spark._jvm.com.amazon.deequ.VerificationResult.rowLevelResultsAsDataFrame(
        spark._jsparkSession, verification_result.verificationRun, df._jdf
    )
)

result_df = DataFrame(java_result_df, spark)
result_df.show()

Can someone please provide a workaround?

Here's the full stack trace for the error:

py4j.Py4JException: An exception was raised by the Python Proxy. Return Message: Traceback (most recent call last):
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/clientserver.py", line 617, in _call_proxy
    return_value = getattr(self.pool[obj_id], method)(*params)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ScalaFunction1' object has no attribute 'hashCode'

        at py4j.Protocol.getReturnValue(Protocol.java:476)
        at py4j.reflection.PythonProxyHandler.invoke(PythonProxyHandler.java:108)
        at jdk.proxy3/jdk.proxy3.$Proxy34.hashCode(Unknown Source)
        at scala.collection.mutable.FlatHashTable.findElemImpl(FlatHashTable.scala:131)
        at scala.collection.mutable.FlatHashTable.containsElem(FlatHashTable.scala:126)
        at scala.collection.mutable.FlatHashTable.containsElem$(FlatHashTable.scala:125)
        at scala.collection.mutable.HashSet.containsElem(HashSet.scala:41)
        at scala.collection.mutable.HashSet.contains(HashSet.scala:58)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:87)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitArray(SerializationDebugger.scala:120)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:100)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:206)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visitSerializable(SerializationDebugger.scala:159)
        at org.apache.spark.serializer.SerializationDebugger$SerializationDebugger.visit(SerializationDebugger.scala:108)
        at org.apache.spark.serializer.SerializationDebugger$.find(SerializationDebugger.scala:67)
        at org.apache.spark.serializer.SerializationDebugger$.improveException(SerializationDebugger.scala:41)
        at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:49)
        at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:115)
        at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:441)
        at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:416)
        at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:163)
        at org.apache.spark.SparkContext.clean(SparkContext.scala:2674)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$1(RDD.scala:904)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
        at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
        at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:903)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:759)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191)
        at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:364)
        at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:498)
        at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
        at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
        at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4344)
        at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3326)
        at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334)
        at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
        at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
        at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332)
        at org.apache.spark.sql.Dataset.head(Dataset.scala:3326)
        at org.apache.spark.sql.Dataset.take(Dataset.scala:3549)
        at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
        at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:569)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
        at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
        at java.base/java.lang.Thread.run(Thread.java:840)
Traceback (most recent call last):
  File "/workspaces/cubie/cubie/transformers/temp.py", line 39, in <module>
    result_df.show()
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/pyspark/sql/dataframe.py", line 959, in show
    print(self._jdf.showString(n, 20, vertical))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/java_gateway.py", line 1322, in __call__
    return_value = get_return_value(
                   ^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/pyspark/errors/exceptions/captured.py", line 179, in deco
    return f(*a, **kw)
           ^^^^^^^^^^^
  File "/opt/conda/envs/dev/lib/python3.12/site-packages/py4j/protocol.py", line 326, in get_return_value
    raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling o82.showString.
: org.apache.spark.SparkException: Task not serializable
        at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:444)
        at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:416)
        at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:163)
        at org.apache.spark.SparkContext.clean(SparkContext.scala:2674)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$1(RDD.scala:904)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
        at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
        at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:903)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:759)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:195)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:246)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:243)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:191)
        at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:364)
        at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:498)
        at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
        at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
        at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4344)
        at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3326)
        at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334)
        at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
        at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
        at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332)
        at org.apache.spark.sql.Dataset.head(Dataset.scala:3326)
        at org.apache.spark.sql.Dataset.take(Dataset.scala:3549)
        at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
        at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:569)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
        at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
        at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: java.io.NotSerializableException: py4j.reflection.PythonProxyHandler
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1187)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1381)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
        at java.base/java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1572)
        at java.base/java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1529)
        at java.base/java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1438)
        at java.base/java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1181)
        at java.base/java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:350)
        at org.apache.spark.serializer.JavaSerializationStream.writeObject(JavaSerializer.scala:46)
        at org.apache.spark.serializer.JavaSerializerInstance.serialize(JavaSerializer.scala:115)
        at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:441)
        ... 45 more

@0xbadidea I think the root of the problem is in py4j, and I do not see how it can be fixed, because py4j.reflection.PythonProxyHandler is not serializable at all (there is a good explanation in the linked py4j discussion). I tried adding scala.Serializable to the list of interfaces implemented by scala_utils.ScalaFunction1, but it did not help because the problem is in py4j itself.
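To make the serialization point concrete, here is a rough pure-Python analogy. It is not py4j or pydeequ code: `CallbackProxy`, `can_serialize`, `is_zero` and the `declarative` dict are hypothetical names chosen for illustration, and a `threading.Lock` stands in for the live gateway connection that a proxy handler holds. The sketch only shows that an object tied to a live resource cannot be turned into bytes, while a plain description of the same check can.

```python
import pickle
import threading


class CallbackProxy:
    """Hypothetical stand-in for a proxy that wraps a Python callback."""

    def __init__(self, func):
        self.func = func
        # Assumption for illustration: the lock plays the role of the live
        # connection back to the Python process. It cannot be pickled.
        self.connection = threading.Lock()

    def apply(self, value):
        return self.func(value)


def can_serialize(obj):
    """Return True if obj can be pickled, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, TypeError, AttributeError):
        return False


def is_zero(x):
    return x == 0


proxy = CallbackProxy(is_zero)
# The same assertion expressed as plain data instead of a live callback.
declarative = {"column": "b", "metric": "min", "expected": 0}

print(proxy.apply(0))              # the callback works locally: True
print(can_serialize(proxy))        # False, the proxy holds a live resource
print(can_serialize(declarative))  # True, plain data serializes fine
```

The analogy suggests why marking the wrapper as serializable does not help: the part that cannot be shipped to executors is the handler behind the callback, not the wrapper class around it.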

Thanks @SemyonSinchenko! Appreciate your help. I guess there's no straightforward way to implement this method in pydeequ.