Skip to content

Commit

Permalink
Merge branch 'smalton/fix-cpu-duplex' into 'master'
Browse files Browse the repository at this point in the history
Fix cpu duplex

See merge request machine-learning/dorado!461
  • Loading branch information
iiSeymour committed Jul 3, 2023
2 parents f5ccd0d + 8ecfe6b commit d2700dd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dorado/nn/ModelRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ ModelRunner<T>::ModelRunner(const CRFModelConfig &model_config,
// adjust chunk size to be a multiple of the stride
chunk_size -= chunk_size % m_model_stride;

m_input = torch::zeros({batch_size, 1, chunk_size},
m_input = torch::zeros({batch_size, model_config.num_features, chunk_size},
torch::TensorOptions().dtype(T::dtype).device(torch::kCPU));
}

Expand All @@ -96,7 +96,7 @@ std::vector<DecodedChunk> ModelRunner<T>::call_chunks(int num_chunks) {

template <typename T>
void ModelRunner<T>::accept_chunk(int chunk_idx, const torch::Tensor &chunk) {
m_input.index_put_({chunk_idx, 0}, chunk);
m_input.index_put_({chunk_idx, torch::indexing::Ellipsis}, chunk);
}

template <typename T>
Expand Down

0 comments on commit d2700dd

Please sign in to comment.