Skip to content

Commit

Permalink
Add python wrappers for configs and inference.
Browse files Browse the repository at this point in the history
Enable building compression/python/compression_test using bazel.
Add default image path for image_test and paligemma_test.

PiperOrigin-RevId: 718857748
  • Loading branch information
danielkeysers authored and copybara-github committed Jan 28, 2025
1 parent a248f76 commit 62aa6d0
Show file tree
Hide file tree
Showing 15 changed files with 750 additions and 13 deletions.
30 changes: 27 additions & 3 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ module(
)

bazel_dep(name = "abseil-cpp", version = "20240722.0")
bazel_dep(name = "bazel_skylib", version = "1.6.1")
bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "googletest", version = "1.15.2")
bazel_dep(name = "highway", version = "1.1.0")
bazel_dep(name = "nlohmann_json", version = "3.11.3")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "rules_cc", version = "0.0.16")
bazel_dep(name = "rules_license", version = "1.0.0")
bazel_dep(name = "rules_python", version = "1.0.0")
bazel_dep(name = "google_benchmark", version = "1.8.5")

# Require a more recent version.
Expand All @@ -23,6 +24,15 @@ git_override(

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "com_google_absl_py",
sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758",
strip_prefix = "abseil-py-2.1.0",
urls = [
"https://github.com/abseil/abseil-py/archive/refs/tags/v2.1.0.tar.gz",
],
)

http_archive(
name = "com_google_sentencepiece",
build_file = "@//bazel:sentencepiece.bazel",
Expand Down Expand Up @@ -53,3 +63,17 @@ cc_library(
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)

pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
pip.parse(
hub_name = "compression_deps",
python_version = "3.11",
requirements_lock = "//compression/python:requirements.txt",
)
use_repo(pip, "compression_deps")
pip.parse(
hub_name = "python_deps",
python_version = "3.11",
requirements_lock = "//python:requirements.txt",
)
use_repo(pip, "python_deps")
15 changes: 15 additions & 0 deletions compression/blob_compare.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <cstdio>
#include <string>
Expand Down
4 changes: 2 additions & 2 deletions compression/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ py_test(
srcs = ["compression_test.py"],
deps = [
":compression",
"//testing/pybase",
"@com_google_absl_py//absl/testing:absltest",
"//python:configs",
"//third_party/py/numpy",
"@compression_deps//numpy",
],
)
15 changes: 15 additions & 0 deletions compression/python/compression_clif_aux.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "compression/python/compression_clif_aux.h"

#include <cstddef>
Expand Down
15 changes: 15 additions & 0 deletions compression/python/compression_clif_aux.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_

Expand Down
15 changes: 15 additions & 0 deletions compression/python/compression_extension.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand Down
23 changes: 19 additions & 4 deletions compression/python/compression_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
# Copyright 2024 Google LLC
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for CLIF wrapped .sbs writer."""

import numpy as np

import unittest
from absl.testing import absltest
from compression.python import compression
from gemma.python import configs
from python import configs


class CompressionTest(unittest.TestCase):
class CompressionTest(absltest.TestCase):

def test_sbs_writer(self):
temp_file = self.create_tempfile("test.sbs")
Expand Down Expand Up @@ -41,4 +56,4 @@ def test_sbs_writer(self):


if __name__ == "__main__":
unittest.main()
absltest.main()
1 change: 1 addition & 0 deletions compression/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
numpy>=1.26.4
3 changes: 1 addition & 2 deletions paligemma/image_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ float Normalize(float value, float max_value = 255.0f) {
}

TEST(ImageTest, LoadResize224GetPatch) {
return; // Need to figure out how to get the external path for the test file.
std::string path;
std::string path = "paligemma/testdata/image.ppm";
Image image;
EXPECT_EQ(image.width(), 0);
EXPECT_EQ(image.height(), 0);
Expand Down
3 changes: 1 addition & 2 deletions paligemma/paligemma_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{

void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
ASSERT_NE(s_env->GetModel(), nullptr);
return; // Need to figure out how to get the external path for the test file.
std::string path;
std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
Expand Down
43 changes: 43 additions & 0 deletions python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# [internal] load py_binary
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = [
"//:license", # Placeholder comment, do not modify
],
default_visibility = ["//visibility:public"],
)

pybind_extension(
name = "configs",
srcs = ["configs.cc"],
deps = [
"//:common",
"//compression:sfp",
],
)

pybind_extension(
name = "gemma",
srcs = ["gemma_py.cc"],
deps = [
"//:app",
"//:benchmark_helper",
"//:gemma_lib",
"//compression:sfp",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

py_binary(
name = "run_example",
srcs = ["run_example.py"],
python_version = "PY3",
deps = [
":gemma",
"@python_deps//absl_py",
# placeholder forabsl/flags
"@compression_deps//numpy",
],
)
Loading

0 comments on commit 62aa6d0

Please sign in to comment.