JohnSnowLabs/spark-nlp

java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph when using pretrained model

hortiprajwal opened this issue · 0 comments

Is there an existing issue for this?

  • I have searched the existing issues and did not find a match.

Who can help?

No response

What are you working on?

I am trying to run Spark NLP in Databricks to summarize text.

Current Behavior

Py4JJavaError: An error occurred while calling o584.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 9.0 failed 4 times, most recent failure: Lost task 2.3 in stage 9.0 (TID 412) (172.16.2.60 executor 0): java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph
	at org.tensorflow.Graph.outputOrThrow(Graph.java:211)
	at org.tensorflow.Session$Runner.feed(Session.java:248)
	at com.johnsnowlabs.ml.ai.Bart.getModelOutput(Bart.scala:414)
	at com.johnsnowlabs.ml.ai.util.Generation.Generate.$anonfun$beamSearch$7(Generate.scala:225)
	at scala.util.control.Breaks.breakable(Breaks.scala:42)
	at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch(Generate.scala:213)
	at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch$(Generate.scala:182)
	at com.johnsnowlabs.ml.ai.Bart.beamSearch(Bart.scala:40)
	at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate(Generate.scala:151)
	at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate$(Generate.scala:85)
	at com.johnsnowlabs.ml.ai.Bart.generate(Bart.scala:40)
	at com.johnsnowlabs.ml.ai.Bart.tag(Bart.scala:280)
	at com.johnsnowlabs.ml.ai.Bart.$anonfun$predict$1(Bart.scala:124)
	at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
	at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
	at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
	at com.johnsnowlabs.ml.ai.Bart.predict(Bart.scala:109)
	at com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.batchAnnotate(BartTransformer.scala:324)
	at com.johnsnowlabs.nlp.HasBatchedAnnotate.processBatchRows(HasBatchedAnnotate.scala:65)
	at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:53)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$5(UnsafeRowBatchUtils.scala:88)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$3(UnsafeRowBatchUtils.scala:88)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$1(UnsafeRowBatchUtils.scala:68)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
	at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$2(Collector.scala:197)
	at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:82)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:82)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:62)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:196)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:181)
	at org.apache.spark.scheduler.Task.$anonfun$run$5(Task.scala:146)
	at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45)
	at com.databricks.unity.HandleImpl.runWith(UCSHandle.scala:103)
	at com.databricks.unity.HandleImpl.$anonfun$runWithAndClose$1(UCSHandle.scala:108)
	at scala.util.Using$.resource(Using.scala:269)
	at com.databricks.unity.HandleImpl.runWithAndClose(UCSHandle.scala:107)
	at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:146)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.scheduler.Task.run(Task.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$8(Executor.scala:900)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1709)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:903)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:798)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)

Expected Behavior

The code should run without errors and print the generated summary.

Steps To Reproduce

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, isnull, length, lit, when
from sparknlp.base import DocumentAssembler, Pipeline
from sparknlp.annotator import BartTransformer


data = spark.createDataFrame(
    [
        (
            """Landing on another planetary body has been one of the biggest challenges and engineers across the world have been working to develop innovating technology to ensure a smooth touchdown, be it the Moon or Mars.

About 27 years ago, as a Nasa-led spacecraft was hurtling down through the thin Martian atmosphere, all eyes were glaring on the deep space network, which was looking for signs of success. Onboard was a revolutionary method that was all about a bouncy landing.

In July 1997, the spacecraft made history by successfully bouncing 15 times on the Martian surface before coming to rest, demonstrating a new and effective method for landing on Mars.

The Pathfinder mission, consisting of a lander and the small Sojourner rover, employed a groundbreaking approach to cushion its impact.""",
        )
    ]
).toDF("content")


class SummarizationModel:
    @staticmethod
    def build_pipeline(input_col: str) -> Pipeline:
        document_assembler = DocumentAssembler().setInputCol(input_col).setOutputCol("document")

        bart_transformer = (
            BartTransformer.pretrained(name="bart_large_cnn", lang="en")
            .setTask("summarize:")
            .setInputCols(["document"])
            .setMaxOutputLength(200)
            .setOutputCol("summaries")
        )

        pipeline = Pipeline().setStages([document_assembler, bart_transformer])
        return pipeline

    def process(self, data: DataFrame, input_col: str) -> DataFrame:
        data = data.withColumn("content_length", when(isnull(col(input_col)), 0).otherwise(length(col(input_col))))

        # generate summary for content length > 200
        long_content_df = data.filter(col("content_length") > 200)
        short_content_df = data.filter(col("content_length") <= 200)
        short_content_df = short_content_df.withColumn("summary", lit(""))

        if long_content_df.count() > 0:
            pipeline = self.build_pipeline(input_col=input_col)
            summarized_df = pipeline.fit(data).transform(long_content_df)
            summarized_df = summarized_df.withColumn("summary", col("summaries.result")[0])
            summarized_df = summarized_df.drop("document", "summaries")
            data = summarized_df.union(short_content_df)
        else:
            data = short_content_df

        return data.drop("content_length")
    

news_articles_df = SummarizationModel().process(data=data, input_col="content")
news_articles_df.select('summary').show(truncate=False)
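
For isolation, below is a minimal single-stage reproducer (a sketch using only the annotators and settings already shown above; the Databricks-provided spark session is assumed). If this also fails with "No Operation named [missing_decoder_input_ids_init] in the Graph", the problem comes from the BartTransformer annotator itself rather than from the surrounding filtering/union logic:

from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BartTransformer

# Same two stages as build_pipeline above, without the length filtering and union.
document_assembler = DocumentAssembler().setInputCol("content").setOutputCol("document")
bart = (
    BartTransformer.pretrained(name="bart_large_cnn", lang="en")
    .setTask("summarize:")
    .setInputCols(["document"])
    .setMaxOutputLength(200)
    .setOutputCol("summaries")
)

# A single short row is enough to trigger the decoder call.
minimal_df = spark.createDataFrame(
    [("In July 1997, the Pathfinder lander bounced 15 times on the Martian surface before coming to rest.",)]
).toDF("content")

result = Pipeline().setStages([document_assembler, bart]).fit(minimal_df).transform(minimal_df)
result.selectExpr("summaries.result[0] AS summary").show(truncate=False)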

Spark NLP version and Apache Spark

Spark NLP version: 5.4.0
Spark version: 3.4.1

Databricks Runtime Version: 13.3 LTS (includes Apache Spark 3.4.1, Scala 2.12)

Type of Spark Application

Python Application

Java Version

No response

Java Home Directory

No response

Setup and installation

Installed the dependencies from Maven and PyPI under Libraries in the Databricks cluster configuration.
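
As a quick sanity check that the attached JAR and the Python wheel match, something like the following can be run on the cluster (the exact coordinates in the comments are what the 5.4.0 release is normally published under, so treat them as an assumption about this particular setup):

# Libraries attached to the cluster (assumed, based on the versions reported above):
#   Maven: com.johnsnowlabs.nlp:spark-nlp_2.12:5.4.0
#   PyPI:  spark-nlp==5.4.0
import sparknlp

print(sparknlp.version())  # should print 5.4.0
print(spark.version)       # should print 3.4.1 on DBR 13.3 LTS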

Operating System and Version

No response

Link to your project (if available)

No response

Additional Information

No response