Skip to content

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 31, 2024
1 parent 27aa6cf commit ce6c3d3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def transcribe_batch(self,
orig_freq=sample_rate,
new_freq=self.resample_rate)(waveform)

waveform = waveform.to(torch.float).to(self.device)
waveform = waveform.to(torch.float)
feats = kaldi.fbank(waveform,
num_mel_bins=80,
frame_length=25,
Expand All @@ -53,7 +53,7 @@ def transcribe_batch(self,
torch.tensor(feats.shape[0], dtype=torch.int64))
feats_tensor = torch.nn.utils.rnn.pad_sequence(
feats_lst, batch_first=True).to(device=self.device)
feats_lens_tensor = torch.tensor(feats_lens_lst)
feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)

decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
feats_tensor, feats_lens_tensor)
Expand Down

0 comments on commit ce6c3d3

Please sign in to comment.