
ValueError: cannot find context for 'fork' when processor_with_lm.batch_decode(_logits) #16898

Closed
2 of 4 tasks
elsheikh21 opened this issue Apr 22, 2022 · 15 comments

@elsheikh21

elsheikh21 commented Apr 22, 2022

System Info

- `transformers` version: 4.17.0
- Platform: Windows-10-10.0.22000-SP0
- Python version: 3.8.13
- PyTorch version (GPU?): 1.9.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No

Who can help?

@patrickvonplaten

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

To reproduce

  • The model I am using (Wav2Vec2.0 Large XLS-R 53 English):

  • Steps to reproduce the behavior:

  1. I fine-tuned Wav2Vec2 with an LM head, using WikiText to build a 5-gram LM. I downloaded the fine-tuned model directory locally and was able to perform inference on my audio .wav file(s).
  2. Please find here the model files, the test audio file, and requirements.txt, if needed to reproduce the problem.

Code snippet

import torch
from os import getcwd
from os.path import join as path_join
from pprint import pprint

import soundfile as sf
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor_path = path_join(getcwd(), "stt_assets", "stt_model")
processor = Wav2Vec2ProcessorWithLM.from_pretrained(processor_path)
  
dataset = load_dataset("timit_asr", split="test").shuffle().shuffle().select(range(100))
char_translations = str.maketrans({"-": " ", ",": "", ".": "", "?": ""})


def prepare_example(example):
    example["speech"], _ = sf.read(example["file"])
    example["text"] = example["text"].translate(char_translations)
    example["text"] = " ".join(example["text"].split())  # clean up whitespace
    example["text"] = example["text"].lower()
    return example
  

dataset = dataset.map(prepare_example, remove_columns=["file"])

pprint(dataset)
features = processor(dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(**features).logits

# logits shape is torch.Size([100, 304, 33])
transcription = processor.batch_decode(logits)
# EXCEPTION IS RAISED in `processor.batch_decode()` ValueError: cannot find context for 'fork'
print(transcription)

Expected behavior

What I am expecting is to get a list of transcriptions from `processor.batch_decode()`, but instead I get this `ValueError: cannot find context for 'fork'` exception. I am using Windows 11.

I have tried to research it and I guess it is related to multiprocessing, but I could not figure out how to solve it yet.
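For context, the set of start methods `multiprocessing` supports can be inspected directly: `"fork"` exists only on POSIX systems, while Windows supports only `"spawn"`, which is exactly why the call fails. A small diagnostic sketch (not part of the original report):

```python
import multiprocessing

# "fork" is only available on POSIX systems; Windows supports "spawn" only.
methods = multiprocessing.get_all_start_methods()
print(methods)

if "fork" not in methods:
    # On Windows, requesting "fork" raises the exact error from this issue:
    try:
        multiprocessing.get_context("fork")
    except ValueError as exc:
        print(exc)  # cannot find context for 'fork'
```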
@elsheikh21 elsheikh21 added the bug label Apr 22, 2022
@patrickvonplaten
Contributor

Related woven-planet/l5kit#129

@patrickvonplaten
Contributor

Hey @elsheikh21,

Let's try to narrow the bug further down :-)

Does the following work for you:

from multiprocessing import get_context

num_processes = 8  # any small value works for this check
pool = get_context("fork").Pool(num_processes)
pool.close()

?

@elsheikh21
Author

elsheikh21 commented Apr 25, 2022

Hello @patrickvonplaten

I have tried to run

from multiprocessing import get_context
num_processes = 8
pool = get_context("fork").Pool(num_processes)
pool.close()

and got the following traceback

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\AhmedElSheikh\AppData\Local\Programs\Python\Python38\lib\multiprocessing\context.py", line 239, in get_context
    return super().get_context(method)
  File "C:\Users\AhmedElSheikh\AppData\Local\Programs\Python\Python38\lib\multiprocessing\context.py", line 193, in get_context
    raise ValueError('cannot find context for %r' % method) from None
ValueError: cannot find context for 'fork'

System Information
Windows 11
Python 3.8.10

@elsheikh21
Author

Related woven-planet/l5kit#129

I have read this thread, yet the error itself occurs when I call processor.batch_decode, and the project I am working on is not meant to be used only on my local device.

@patrickvonplaten
Contributor

Hello @patrickvonplaten

I have tried to run

from multiprocessing import get_context
num_processes = 8
pool = get_context("fork").Pool(num_processes)
pool.close()

and got the following traceback

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\AhmedElSheikh\AppData\Local\Programs\Python\Python38\lib\multiprocessing\context.py", line 239, in get_context
    return super().get_context(method)
  File "C:\Users\AhmedElSheikh\AppData\Local\Programs\Python\Python38\lib\multiprocessing\context.py", line 193, in get_context
    raise ValueError('cannot find context for %r' % method) from None
ValueError: cannot find context for 'fork'

System Information: Windows 11, Python 3.8.10

This seems to be the error then.

Could you try to replace "fork" with "spawn"?
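For reference, the same check with `"spawn"` should succeed on any platform, since `"spawn"` starts a fresh interpreter rather than forking the parent. A sketch, assuming it runs as a script (the `__main__` guard is what `"spawn"` requires so workers do not re-execute the script on import):

```python
from multiprocessing import get_context

if __name__ == "__main__":
    num_processes = 8
    # "spawn" is available everywhere, including Windows.
    pool = get_context("spawn").Pool(num_processes)
    pool.close()
    pool.join()
```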

@patrickvonplaten
Contributor

If "spawn" works, then it might make the most sense to just update "fork" to "spawn".

@elsheikh21
Author

I have tried running with "spawn" and it works fine, but in that case I would need to change the file transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py, and I guess that won't work when I run the same code on another machine. Is there a way to force "spawn" when "fork" is not available?

@patrickvonplaten
Contributor

Think we can just replace "fork" with "spawn" - do you want to open a PR to fix it? :-)

@elsheikh21
Author

elsheikh21 commented Apr 25, 2022

Think we can just replace "fork" with "spawn" - do you want to open a PR to fix it? :-)

Yes, I would happily do that. I guess it would be something along these lines? Please feel free to modify my approach; otherwise I will start reading about contributing and how to open a PR.

try:
    pool = get_context("fork").Pool(num_processes)
except ValueError as exc:
    if "cannot find context for 'fork'" in str(exc):
        pool = get_context("spawn").Pool(num_processes)
        logging.info('Switching to "spawn" as "fork" context is not found')
    else:
        raise
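An alternative sketch to a try/except: pick the start method up front based on what the current platform supports (the `get_pool` helper name is mine, not part of `transformers`):

```python
from multiprocessing import get_all_start_methods, get_context

def get_pool(num_processes):
    # Prefer "fork" where it exists (POSIX); fall back to "spawn",
    # the only start method available on Windows.
    method = "fork" if "fork" in get_all_start_methods() else "spawn"
    return get_context(method).Pool(num_processes)

if __name__ == "__main__":
    pool = get_pool(2)
    pool.close()
    pool.join()
```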

@patrickvonplaten
Contributor

I think we can actually just change "fork" to "spawn" (no need for a try/except, IMO). According to https://stackoverflow.com/questions/64095876/multiprocessing-fork-vs-spawn and some other docs, "spawn" is safe, and given that the child process does LM-boosted decoding (which is always slow), the switch should be fine.

@elsheikh21
Author

I think we can actually just change "fork" to "spawn" (no need for a try/except, IMO). According to https://stackoverflow.com/questions/64095876/multiprocessing-fork-vs-spawn and some other docs, "spawn" is safe, and given that the child process does LM-boosted decoding (which is always slow), the switch should be fine.

Okay, let's do it your way then. I have also created a custom dataset loader (from flac/wav audio files), a model fine-tuner, and an evaluator; if those can be helpful to the community, I would love to share them as well.

For now I will open a PR for the spawn/fork change.

@ADD-eNavarro

Exactly the same problem here: also trying to run this under Windows 10 and getting the same error, raised when processing_wav2vec2_with_lm.py, line 316, requests the "fork" context.
But since I see it's already being fixed, I'll just say thanks and wait 👍

@elsheikh21
Author

elsheikh21 commented Jun 15, 2022

Exactly the same problem here: also trying to run this under Windows 10 and getting the same error, raised when processing_wav2vec2_with_lm.py, line 316, requests the "fork" context. But since I see it's already being fixed, I'll just say thanks and wait 👍

As a quick fix, you can replace "fork" with "spawn" in the line `pool = get_context("fork").Pool(num_processes)` in transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py.

@patrickvonplaten
Contributor

@ADD-eNavarro @elsheikh21 sorry, I don't usually work with Windows and am a bit buried with other issues. Regarding the PR, please let me know if anything isn't clear; I'm happy to be more precise. In short, I think we should apply the exact same solution that was applied in pyctcdecode.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
