java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph when using pretrained model
hortiprajwal opened this issue · 0 comments
hortiprajwal commented
Is there an existing issue for this?
- I have searched the existing issues and did not find a match.
Who can help?
No response
What are you working on?
I am trying to run Spark NLP in Databricks to summarize text with the pretrained bart_large_cnn model.
Current Behavior
Py4JJavaError: An error occurred while calling o584.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 9.0 failed 4 times, most recent failure: Lost task 2.3 in stage 9.0 (TID 412) (172.16.2.60 executor 0): java.lang.IllegalArgumentException: No Operation named [missing_decoder_input_ids_init] in the Graph
at org.tensorflow.Graph.outputOrThrow(Graph.java:211)
at org.tensorflow.Session$Runner.feed(Session.java:248)
at com.johnsnowlabs.ml.ai.Bart.getModelOutput(Bart.scala:414)
at com.johnsnowlabs.ml.ai.util.Generation.Generate.$anonfun$beamSearch$7(Generate.scala:225)
at scala.util.control.Breaks.breakable(Breaks.scala:42)
at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch(Generate.scala:213)
at com.johnsnowlabs.ml.ai.util.Generation.Generate.beamSearch$(Generate.scala:182)
at com.johnsnowlabs.ml.ai.Bart.beamSearch(Bart.scala:40)
at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate(Generate.scala:151)
at com.johnsnowlabs.ml.ai.util.Generation.Generate.generate$(Generate.scala:85)
at com.johnsnowlabs.ml.ai.Bart.generate(Bart.scala:40)
at com.johnsnowlabs.ml.ai.Bart.tag(Bart.scala:280)
at com.johnsnowlabs.ml.ai.Bart.$anonfun$predict$1(Bart.scala:124)
at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
at com.johnsnowlabs.ml.ai.Bart.predict(Bart.scala:109)
at com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.batchAnnotate(BartTransformer.scala:324)
at com.johnsnowlabs.nlp.HasBatchedAnnotate.processBatchRows(HasBatchedAnnotate.scala:65)
at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:53)
at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$5(UnsafeRowBatchUtils.scala:88)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$3(UnsafeRowBatchUtils.scala:88)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.$anonfun$encodeUnsafeRows$1(UnsafeRowBatchUtils.scala:68)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$2(Collector.scala:197)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:82)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:82)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:62)
at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:196)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:181)
at org.apache.spark.scheduler.Task.$anonfun$run$5(Task.scala:146)
at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45)
at com.databricks.unity.HandleImpl.runWith(UCSHandle.scala:103)
at com.databricks.unity.HandleImpl.$anonfun$runWithAndClose$1(UCSHandle.scala:108)
at scala.util.Using$.resource(Using.scala:269)
at com.databricks.unity.HandleImpl.runWithAndClose(UCSHandle.scala:107)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:146)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:99)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$8(Executor.scala:900)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1709)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:903)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:798)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:750)
Expected Behavior
The code should print the summary and run without errors.
Steps To Reproduce
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, length, lit, isnull, when
from sparknlp.base import Pipeline, DocumentAssembler
from sparknlp.annotator import BartTransformer
data = spark.createDataFrame(
    [
        (
            """Landing on another planetary body has been one of the biggest challenges and engineers across the world have been working to develop innovating technology to ensure a smooth touchdown, be it the Moon or Mars.
About 27 years ago, as a Nasa-led spacecraft was hurtling down through the thin Martian atmosphere, all eyes were glaring on the deep space network, which was looking for signs of success. Onboard was a revolutionary method that was all about a bouncy landing.
In July 1997, the spacecraft made history by successfully bouncing 15 times on the Martian surface before coming to rest, demonstrating a new and effective method for landing on Mars.
The Pathfinder mission, consisting of a lander and the small Sojourner rover, employed a groundbreaking approach to cushion its impact.""",
        )
    ]
).toDF("content")


class SummarizationModel:
    @staticmethod
    def build_pipeline(input_col: str) -> Pipeline:
        document_assembler = DocumentAssembler().setInputCol(input_col).setOutputCol("document")
        bart_transformer = (
            BartTransformer.pretrained(name="bart_large_cnn", lang="en")
            .setTask("summarize:")
            .setInputCols(["document"])
            .setMaxOutputLength(200)
            .setOutputCol("summaries")
        )
        pipeline = Pipeline().setStages([document_assembler, bart_transformer])
        return pipeline

    def process(self, data: DataFrame, input_col: str) -> DataFrame:
        data = data.withColumn("content_length", when(isnull(col(input_col)), 0).otherwise(length(col(input_col))))
        # generate a summary only for content longer than 200 characters
        long_content_df = data.filter(col("content_length") > 200)
        short_content_df = data.filter(col("content_length") <= 200)
        short_content_df = short_content_df.withColumn("summary", lit(""))
        if long_content_df.count() > 0:
            pipeline = self.build_pipeline(input_col=input_col)
            summarized_df = pipeline.fit(data).transform(long_content_df)
            summarized_df = summarized_df.withColumn("summary", col("summaries.result")[0])
            summarized_df = summarized_df.drop("document", "summaries")
            data = summarized_df.union(short_content_df)
        else:
            data = short_content_df
        return data.drop("content_length")


news_articles_df = SummarizationModel().process(data=data, input_col="content")
news_articles_df.select("summary").show(truncate=False)
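If it helps with triage, here is a reduced sketch that isolates the pretrained BartTransformer from the length-filtering logic above. It reuses the same stages and parameters as the script; I have not confirmed whether this minimal form fails in exactly the same way.

# Reduced sketch: same annotators and parameters as the script above, no length filtering.
# (Included for triage only; not independently verified to reproduce the failure.)
from sparknlp.base import Pipeline, DocumentAssembler
from sparknlp.annotator import BartTransformer

document_assembler = DocumentAssembler().setInputCol("content").setOutputCol("document")
bart = (
    BartTransformer.pretrained(name="bart_large_cnn", lang="en")
    .setTask("summarize:")
    .setInputCols(["document"])
    .setMaxOutputLength(200)
    .setOutputCol("summaries")
)
minimal_df = Pipeline().setStages([document_assembler, bart]).fit(data).transform(data)
minimal_df.selectExpr("summaries.result[0] AS summary").show(truncate=False)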
Spark NLP version and Apache Spark
Spark NLP version: 5.4.0
Spark version: 3.4.1
Databricks Runtime Version: 13.3 LTS (includes Apache Spark 3.4.1, Scala 2.12)
Type of Spark Application
Python Application
Java Version
No response
Java Home Directory
No response
Setup and installation
Installed the dependencies from Maven and PyPI under Libraries in Databricks.
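The snippet below is roughly how I verify the environment on the cluster; the package coordinates in the comments are my assumption of what the Libraries UI installed and may need adjusting.

# Environment check (coordinates are assumed, matching the reported Spark NLP 5.4.0):
#   Maven: com.johnsnowlabs.nlp:spark-nlp_2.12:5.4.0
#   PyPI:  spark-nlp==5.4.0
import sparknlp

print(sparknlp.version())  # Spark NLP Python package version
print(spark.version)       # Spark version of the active Databricks session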
Operating System and Version
No response
Link to your project (if available)
No response
Additional Information
No response