Update usage of deprecated API (get_weights) in TensorFlow 2 Optimizer. #7319

Open
lalalapotter opened this issue Jan 19, 2023 · 2 comments

@lalalapotter
Contributor

I encountered the following error when running an LSTM model with the Orca PySpark TF2 estimator:

Traceback (most recent call last):
  File "lstm_pollution.py", line 33, in <module>
    steps_per_epoch= df.count() // batch_size)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/bigdl/orca/learn/tf2/pyspark_estimator.py", line 199, in fit
    lambda iter: transform_func(iter, init_params, params)).collect()
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/pyspark/rdd.py", line 816, in collect
    sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(12, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/pyspark/rdd.py", line 2465, in func
    return f(iterator)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/bigdl/orca/learn/tf2/pyspark_estimator.py", line 199, in <lambda>
    lambda iter: transform_func(iter, init_params, params)).collect()
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/bigdl/orca/learn/tf2/pyspark_estimator.py", line 196, in transform_func
    return SparkRunner(**init_param).step(**param)
  File "/home/cengguang/anaconda3/envs/bigdl2.1/lib/python3.7/site-packages/bigdl/orca/learn/tf2/spark_runner.py", line 349, in step
    "optimizer_weights": model.optimizer.get_weights()
AttributeError: 'Adam' object has no attribute 'get_weights'

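I am using TensorFlow 2.11 and calling estimator.fit() with the model_dir argument specified. It seems that the program uses a deprecated API (get_weights) to read the weights from the optimizer, and the Adam optimizer in this version no longer provides it. Please check whether we need to update the usage.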
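For illustration, a version-tolerant way to read the optimizer state might look like the sketch below. The helper name `get_optimizer_weights` is hypothetical and not part of BigDL or TensorFlow. The fallback assumes that newer optimizers expose their state through `variables` (either as a method or as a list-like attribute) and that each variable has a `numpy()` method. The sketch is duck-typed, so it does not import TensorFlow.

```python
def get_optimizer_weights(optimizer):
    """Return the optimizer state as a list of array-like values.

    Hypothetical helper: tries the older get_weights() API first, then
    falls back to the optimizer's variables.
    """
    # Older optimizers expose get_weights() directly.
    get_weights = getattr(optimizer, "get_weights", None)
    if callable(get_weights):
        return get_weights()

    # Assumption: newer optimizers expose their state through `variables`,
    # which may be a method or a list-like attribute depending on the version.
    variables = getattr(optimizer, "variables", None)
    if variables is None:
        raise AttributeError(
            f"{type(optimizer).__name__} exposes neither get_weights() nor variables"
        )
    if callable(variables):
        variables = variables()

    # Assumption: each variable provides numpy() to read its current value.
    return [v.numpy() for v in variables]
```

With such a helper, the failing expression in the traceback, `model.optimizer.get_weights()`, could be written as `get_optimizer_weights(model.optimizer)`. Restoring the state on resume would need a matching setter, which is not sketched here.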

@lalalapotter lalalapotter changed the title Update usage of Depracate API (get_weights) in TensorFlow 2 Optimizer. Update usage of Deprecated API (get_weights) in TensorFlow 2 Optimizer. Jan 19, 2023
@lalalapotter lalalapotter changed the title Update usage of Deprecated API (get_weights) in TensorFlow 2 Optimizer. Update usage of deprecated API (get_weights) in TensorFlow 2 Optimizer. Jan 19, 2023
@sgwhat
Contributor

sgwhat commented Feb 2, 2023

This is caused by a breaking change that the TensorFlow team introduced in version 2.11.0. We are still looking into a general solution.
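As a possible interim workaround (a sketch, not a confirmed fix for this issue), user code can build the model with the pre-2.11 optimizer implementation, which TensorFlow 2.11 keeps under `tf.keras.optimizers.legacy` and which still provides `get_weights()`. The helper name `pick_adam_class` is hypothetical. It takes the optimizers namespace as an argument so that it also works on versions without a `legacy` namespace.

```python
def pick_adam_class(optimizers_namespace):
    """Return the legacy Adam class if the namespace has one, else the default Adam.

    Hypothetical helper: `optimizers_namespace` is expected to be something
    like `tf.keras.optimizers`.
    """
    legacy = getattr(optimizers_namespace, "legacy", None)
    if legacy is not None and hasattr(legacy, "Adam"):
        return legacy.Adam
    return optimizers_namespace.Adam


# Usage with a real TensorFlow install (not executed here):
#   import tensorflow as tf
#   Adam = pick_adam_class(tf.keras.optimizers)
#   model.compile(optimizer=Adam(learning_rate=1e-3), loss="mse")
```

This only avoids the error on the user side. The call in spark_runner.py would still fail for any model compiled with the new optimizer classes.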

@sgwhat
Contributor

sgwhat commented Feb 3, 2023

keras-team/tf-keras#442 may be a useful solution.
