Skip to content

Commit

Permalink
add "pymarian" CLI, a proxy to "marian" binary
Browse files Browse the repository at this point in the history
  • Loading branch information
Thamme Gowda committed Aug 9, 2024
1 parent a6ab8af commit 58e246f
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fixed compilation with clang 16.0.6
- Added Threads::Threads to `EXT_LIBS`
- Updates to pymarian: building for multiple python versions; disabling tcmalloc; hosting gated COMETs on HuggingFace
- Add "pymarian" CLI, a proxy to "marian" binary, but made available in PATH after "pip install pymarian"

### Added
- Added `--normalize-gradient-by-ratio` to mildly adapt gradient magnitude if effective batch size diverges from running average effective batch size.
Expand Down
48 changes: 37 additions & 11 deletions src/command/marian_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,51 @@
#include "marian_conv.cpp"
#undef main

#include <string>
#include <map>
#include <tuple>
#include "3rd_party/ExceptionWithCallStack.h"
#include "3rd_party/spdlog/details/format.h"

int main(int argc, char** argv) {
using namespace marian;
using MainFunc = int(*)(int, char**);
std::map<std::string, std::tuple<MainFunc, std::string>> subcmds = {
{"train", {&mainTrainer, "Train a model (default)"}},
{"decode", {&mainDecoder, "Decode or translate text"}},
{"score", {&mainScorer, "Score translations"}},
{"embed", {&mainEmbedder, "Embed text"}},
{"evaluate", {&mainEvaluator, "Run Evaluator metric"}},
{"vocab", {&mainVocab, "Create vocabulary"}},
{"convert", {&mainConv, "Convert model file format"}}
};
// no arguments, or the first arg is "?"", print help message
if (argc == 1 || (argc == 2 && (std::string(argv[1]) == "?") )) {
std::cout << "Usage: " << argv[0] << " COMMAND [ARGS]" << std::endl;
std::cout << "Commands:" << std::endl;
for (auto&& [name, val] : subcmds) {
std::cerr << fmt::format("{:10} : {}\n", name, std::get<1>(val));
}
return 0;
}

if(argc > 1 && argv[1][0] != '-') {
if (argc > 1 && argv[1][0] != '-') {
std::string cmd = argv[1];
argc--;
argv[1] = argv[0];
argv++;
if(cmd == "train") return mainTrainer(argc, argv);
else if(cmd == "decode") return mainDecoder(argc, argv);
else if (cmd == "score") return mainScorer(argc, argv);
else if (cmd == "embed") return mainEmbedder(argc, argv);
else if (cmd == "evaluate") return mainEvaluator(argc, argv);
else if (cmd == "vocab") return mainVocab(argc, argv);
else if (cmd == "convert") return mainConv(argc, argv);
std::cerr << "Command must be train, decode, score, embed, vocab, or convert." << std::endl;
exit(1);
} else
if (subcmds.count(cmd) > 0) {
auto [func, desc] = subcmds[cmd];
return func(argc, argv);
}
else {
std::cerr << "Unknown command: " << cmd << ". Known commands are:" << std::endl;
for (auto&& [name, val] : subcmds) {
std::cerr << fmt::format("{:10} : {}\n", name, std::get<1>(val));
}
return 1;
}
}
else
return mainTrainer(argc, argv);
}
27 changes: 24 additions & 3 deletions src/python/binding/bind.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#define PYBIND11_DETAILED_ERROR_MESSAGES

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
// if your IDE/vscode complains about missing paths
Expand All @@ -6,13 +8,30 @@
#include "evaluator.hpp"
#include "trainer.hpp"
#include "translator.hpp"


#define PYBIND11_DETAILED_ERROR_MESSAGES
#include "command/marian_main.cpp"

namespace py = pybind11;
using namespace pymarian;

/**
* @brief Wrapper function to call Marian main entry point from Python
*
* Calls Marian main entry point from Python.
* It converts args from a vector of strings (Python-ic API) to char* (C API)
* before passsing on to the main function.
* @param args vector of strings
* @return int return code
*/
int main_wrap(std::vector<std::string> args) {
// Convert vector of strings to vector of char*
std::vector<char*> argv;
argv.push_back(const_cast<char*>("pymarian"));
for (auto& arg : args) {
argv.push_back(const_cast<char*>(arg.c_str()));
}
argv.push_back(nullptr);
return main(argv.size() - 1, argv.data());
}

PYBIND11_MODULE(_pymarian, m) {
m.doc() = "Marian C++ API bindings via pybind11";
Expand Down Expand Up @@ -44,5 +63,7 @@ PYBIND11_MODULE(_pymarian, m) {
.def("embed", py::overload_cast<>(&PyEmbedder::embed))
;

m.def("main", &main_wrap, "Marian main entry point");

}

27 changes: 22 additions & 5 deletions src/python/pymarian/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from itertools import islice
from pathlib import Path
import sys
from typing import Iterator, List, Optional, Tuple, Union

import _pymarian
Expand Down Expand Up @@ -46,8 +47,8 @@ def model_type(self) -> str:
@classmethod
def new(
cls,
model_file: Path,
vocab_file: Path = None,
model_file: Union[Path, str],
vocab_file: Union[Path, str] = None,
devices: Optional[List[int]] = None,
width=Defaults.FLOAT_PRECISION,
mini_batch=Defaults.MINI_BATCH,
Expand Down Expand Up @@ -76,8 +77,8 @@ def new(
:return: iterator of scores
"""

assert model_file.exists(), f'Model file {model_file} does not exist'
assert vocab_file.exists(), f'Vocab file {vocab_file} does not exist'
assert Path(model_file).exists(), f'Model file {model_file} does not exist'
assert Path(vocab_file).exists(), f'Vocab file {vocab_file} does not exist'
assert like in Defaults.MODEL_TYPES, f'Unknown model type: {like}'
n_inputs = len(Defaults.MODEL_TYPES[like])
vocabs = [vocab_file] * n_inputs
Expand All @@ -97,7 +98,7 @@ def new(
cpu_threads=cpu_threads,
average=average,
)
if kwargs.pop('fp16'):
if kwargs.pop('fp16', False):
kwargs['fp16'] = '' # empty string for flag; i.e, "--fp16" and not "--fp16=true"

# TODO: remove this when c++ bindings supports iterator
Expand Down Expand Up @@ -171,3 +172,19 @@ def __init__(self, cli_string='', **kwargs):
"""
cli_string += ' ' + kwargs_to_cli(**kwargs)
super().__init__(cli_string.stip())

def main():
"""proxy to marian main function"""
code = _pymarian.main(sys.argv[1:])
sys.exit(code)

def help(*vargs):
"""print help text"""
args = []
args += vargs
if '--help' not in args and '-h' not in args:
args.append('--help')
# note: this will print to stdout
_pymarian.main(args)
# do not exit, as this is a library function

1 change: 1 addition & 0 deletions src/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
]

[project.scripts]
pymarian = "pymarian:main"
pymarian-eval = "pymarian.eval:main"
pymarian-qtdemo = "pymarian.qtdemo:main"
pymarian-mtapi = "pymarian.mtapi_server:main"
Expand Down

0 comments on commit 58e246f

Please sign in to comment.