Skip to content

Commit

Permalink
add decoder bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Dec 17, 2021
1 parent 24baa24 commit 15075a7
Show file tree
Hide file tree
Showing 6 changed files with 722 additions and 0 deletions.
111 changes: 111 additions & 0 deletions torchaudio/csrc/decoder/bindings/_decoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"

namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;

/**
* Some hackery that lets pybind11 handle shared_ptr<void> (for old LMStatePtr).
* See: https://github.com/pybind/pybind11/issues/820
* PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);
* and inside PYBIND11_MODULE
* py::class_<std::shared_ptr<void>>(m, "encapsulated_data");
*/

namespace {

/**
* A pybind11 "alias type" for abstract class LM, allowing one to subclass LM
* with a custom LM defined purely in Python. For those who don't want to build
* with KenLM, or have their own custom LM implementation.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html
*
* TODO: ensure this works. Last time Jeff tried this there were slicing issues,
* see https://github.com/pybind/pybind11/issues/1546 for workarounds.
* This is low-pri since we assume most people can just build with KenLM.
*/
class PyLM : public LM {
using LM::LM;

// needed for pybind11 or else it won't compile
using LMOutput = std::pair<LMStatePtr, float>;

LMStatePtr start(bool startWithNothing) override {
PYBIND11_OVERLOAD_PURE(LMStatePtr, LM, start, startWithNothing);
}

LMOutput score(const LMStatePtr& state, const int usrTokenIdx) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, score, state, usrTokenIdx);
}

LMOutput finish(const LMStatePtr& state) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, finish, state);
}
};

/**
* Using custom python LMState derived from LMState is not working with
* custom python LM (derived from PyLM) because we need to to custing of LMState
* in score and finish functions to the derived class
* (for example vie obj.__class__ = CustomPyLMSTate) which cause the error
* "TypeError: __class__ assignment: 'CustomPyLMState' deallocator differs
* from 'flashlight.text.decoder._decoder.LMState'"
* details see in https://github.com/pybind/pybind11/issues/1640
* To define custom LM you can introduce map inside LM which maps LMstate to
* additional state info (shared pointers pointing to the same underlying object
* will have the same id in python in functions score and finish)
*
* ```python
* from flashlight.lib.text.decoder import LM
* class MyPyLM(LM):
* mapping_states = dict() # store simple additional int for each state
*
* def __init__(self):
* LM.__init__(self)
*
* def start(self, start_with_nothing):
* state = LMState()
* self.mapping_states[state] = 0
* return state
*
* def score(self, state, index):
* outstate = state.child(index)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -numpy.random.random())
*
* def finish(self, state):
* outstate = state.child(-1)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -1)
*```
*/
void LexiconDecoder_decodeStep(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}

std::vector<DecodeResult> LexiconDecoder_decode(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}

} // namespace
31 changes: 31 additions & 0 deletions torchaudio/csrc/decoder/bindings/_dictionary.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"

namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;

namespace {

void Dictionary_addEntry_0(
Dictionary& dict,
const std::string& entry,
int idx) {
dict.addEntry(entry, idx);
}

void Dictionary_addEntry_1(Dictionary& dict, const std::string& entry) {
dict.addEntry(entry);
}

} // namespace
19 changes: 19 additions & 0 deletions torchaudio/csrc/decoder/bindings/cmake/Buildpybind11.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
include(ExternalProject)

set(pybind11_URL https://github.com/pybind/pybind11.git)
set(pybind11_TAG 9a19306fbf30642ca331d0ec88e7da54a96860f9) # release 2.2.4

# Download pybind11
ExternalProject_Add(
pybind11
PREFIX pybind11
GIT_REPOSITORY ${pybind11_URL}
GIT_TAG ${pybind11_TAG}
BUILD_IN_SOURCE 0
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)

ExternalProject_Get_Property(pybind11 SOURCE_DIR)
set(pybind11_INCLUDE_DIR "${SOURCE_DIR}/include")
198 changes: 198 additions & 0 deletions torchaudio/csrc/decoder/bindings/cmake/FindPythonLibsNew.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# - Find python libraries
# This module finds the libraries corresponding to the Python interpreter
# FindPythonInterp provides.
# This code sets the following variables:
#
# PYTHONLIBS_FOUND - have the Python libs been found
# PYTHON_PREFIX - path to the Python installation
# PYTHON_LIBRARIES - path to the python library
# PYTHON_INCLUDE_DIRS - path to where Python.h is found
# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd'
# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string
# PYTHON_SITE_PACKAGES - path to installation site-packages
# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build
#
# Thanks to talljimbo for the patch adding the 'LDVERSION' config
# variable usage.

#=============================================================================
# Copyright 2001-2009 Kitware, Inc.
# Copyright 2012 Continuum Analytics, Inc.
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the names of Kitware, Inc., the Insight Software Consortium,
# nor the names of their contributors may be used to endorse or promote
# products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#=============================================================================

# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`.
if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION)
return()
endif()

# Use the Python interpreter to find the libs.
if(PythonLibsNew_FIND_REQUIRED)
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED)
else()
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION})
endif()

if(NOT PYTHONINTERP_FOUND)
set(PYTHONLIBS_FOUND FALSE)
return()
endif()

# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter
# testing whether sys has the gettotalrefcount function is a reliable, cross-platform
# way to detect a CPython debug interpreter.
#
# The library suffix is from the config var LDVERSION sometimes, otherwise
# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows.
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
"from distutils import sysconfig as s;import sys;import struct;
print('.'.join(str(v) for v in sys.version_info));
print(sys.prefix);
print(s.get_python_inc(plat_specific=True));
print(s.get_python_lib(plat_specific=True));
print(s.get_config_var('SO'));
print(hasattr(sys, 'gettotalrefcount')+0);
print(struct.calcsize('@P'));
print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION'));
print(s.get_config_var('LIBDIR') or '');
print(s.get_config_var('MULTIARCH') or '');
"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES
ERROR_VARIABLE _PYTHON_ERROR_VALUE)

if(NOT _PYTHON_SUCCESS MATCHES 0)
if(PythonLibsNew_FIND_REQUIRED)
message(FATAL_ERROR
"Python config failure:\n${_PYTHON_ERROR_VALUE}")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()

# Convert the process output into a list
if(WIN32)
string(REGEX REPLACE "\\\\" "/" _PYTHON_VALUES ${_PYTHON_VALUES})
endif()
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST)
list(GET _PYTHON_VALUES 1 PYTHON_PREFIX)
list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR)
list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES)
list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION)
list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG)
list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P)
list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX)
list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR)
list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH)

# Make sure the Python has the same pointer-size as the chosen compiler
# Skip if CMAKE_SIZEOF_VOID_P is not defined
if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}"))
if(PythonLibsNew_FIND_REQUIRED)
math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8")
math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8")
message(FATAL_ERROR
"Python config failure: Python is ${_PYTHON_BITS}-bit, "
"chosen compiler is ${_CMAKE_BITS}-bit")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()

# The built-in FindPython didn't always give the version numbers
string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST})
list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR)
list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR)
list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH)

# Make sure all directory separators are '/'
string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX})
string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR})
string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES})

if(CMAKE_HOST_WIN32)
set(PYTHON_LIBRARY
"${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")

# when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the
# original python installation. They may be found relative to PYTHON_INCLUDE_DIR.
if(NOT EXISTS "${PYTHON_LIBRARY}")
get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY)
set(PYTHON_LIBRARY
"${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
endif()

# raise an error if the python libs are still not found.
if(NOT EXISTS "${PYTHON_LIBRARY}")
message(FATAL_ERROR "Python libraries not found")
endif()

else()
if(PYTHON_MULTIARCH)
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}")
else()
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}")
endif()
#message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}")
# Probably this needs to be more involved. It would be nice if the config
# information the python interpreter itself gave us were more complete.
find_library(PYTHON_LIBRARY
NAMES "python${PYTHON_LIBRARY_SUFFIX}"
PATHS ${_PYTHON_LIBS_SEARCH}
NO_DEFAULT_PATH)

# If all else fails, just set the name/version and let the linker figure out the path.
if(NOT PYTHON_LIBRARY)
set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX})
endif()
endif()

MARK_AS_ADVANCED(
PYTHON_LIBRARY
PYTHON_INCLUDE_DIR
)

# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the
# cache entries because they are meant to specify the location of a single
# library. We now set the variables listed by the documentation for this
# module.
SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}")
SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}")
SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}")

find_package_message(PYTHON
"Found PythonLibs: ${PYTHON_LIBRARY}"
"${PYTHON_EXECUTABLE}${PYTHON_VERSION}")

set(PYTHONLIBS_FOUND TRUE)
Loading

0 comments on commit 15075a7

Please sign in to comment.