Skip to content

Commit

Permalink
This breaks internal tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 528602175
  • Loading branch information
eglanz authored and copybara-github committed May 1, 2023
1 parent aa102d1 commit 70afaad
Show file tree
Hide file tree
Showing 19 changed files with 98 additions and 127 deletions.
11 changes: 1 addition & 10 deletions demos/run_vizier_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from absl import flags
from absl import logging

from vizier import service
from vizier.service import servers

flags.DEFINE_string(
Expand All @@ -43,12 +42,6 @@
'Host location for the server. For distributed cases, use the IP address.',
)

flags.DEFINE_string(
'database_url',
service.SQL_LOCAL_URL,
'Location of the database for saving studies.',
)

FLAGS = flags.FLAGS

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
Expand All @@ -58,9 +51,7 @@ def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

server = servers.DefaultVizierServer(
host=FLAGS.host, database_url=FLAGS.database_url
)
server = servers.DefaultVizierServer(host=FLAGS.host)
logging.info('Address to Vizier Server is: %s', server.endpoint)

# prevent the main thread from exiting
Expand Down
4 changes: 0 additions & 4 deletions vizier/_src/pyglove/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,9 @@

import pyglove as pg
from vizier._src.pyglove import oss_vizier
from vizier._src.service import clients as pyvizier_clients
from absl.testing import absltest


pyvizier_clients.environment_variables.servicer_use_sql_ram()


class PygloveTest(absltest.TestCase):
"""Tests for using Vizier as PyGlove backend."""

Expand Down
4 changes: 3 additions & 1 deletion vizier/_src/pyglove/oss_vizier.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def use_vizier_service(self, endpoint: Optional[None]) -> None:
pyvizier_clients.environment_variables.server_endpoint = (
self._vizier_endpoint
)
self._vizier_service = vizier_client.create_vizier_servicer_or_stub()
self._vizier_service = vizier_client.create_vizier_servicer_or_stub(
self._vizier_endpoint
)

@property
def vizier_service(self) -> vizier_types.VizierService:
Expand Down
5 changes: 1 addition & 4 deletions vizier/_src/pyglove/oss_vizier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from vizier._src.pyglove import oss_vizier as vizier
from vizier._src.pyglove import vizier_test_lib
from vizier._src.service import constants
from vizier._src.service import vizier_server

from absl.testing import absltest
Expand All @@ -31,9 +30,7 @@ class OSSVizierSampleTest(vizier_test_lib.SampleTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
server = vizier_server.DefaultVizierServer(
host=os.uname()[1], database_url=constants.SQL_MEMORY_URL
)
server = vizier_server.DefaultVizierServer(host=os.uname()[1])
logging.info(server.endpoint)
vizier._services.reset_for_testing()
vizier.init(vizier_endpoint=server.endpoint)
Expand Down
5 changes: 1 addition & 4 deletions vizier/_src/pyglove/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pyglove as pg

from vizier._src.pyglove import oss_vizier as vizier
from vizier._src.service import constants
from vizier._src.service import vizier_server

from absl.testing import absltest
Expand All @@ -39,9 +38,7 @@ class PerformanceTest(parameterized.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
server = vizier_server.DefaultVizierServer(
host=os.uname()[1], database_url=constants.SQL_MEMORY_URL
)
server = vizier_server.DefaultVizierServer(host=os.uname()[1])
logging.info(server.endpoint)
vizier._services.reset_for_testing()
vizier.init(vizier_endpoint=server.endpoint)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from __future__ import annotations

from vizier._src.service import basic_datastore
from vizier._src.service import datastore_test_lib
from vizier._src.service import ram_datastore
from vizier._src.service import vizier_service_pb2
from vizier._src.service.testing import util as test_util

Expand All @@ -30,7 +30,7 @@ def setUp(self):
self.owner_id = 'my_username'
self.study_id = '123123123'
self.client_id = 'client_0'
self.datastore = ram_datastore.NestedDictRAMDataStore()
self.datastore = basic_datastore.NestedDictRAMDataStore()
self.example_study = test_util.generate_study(self.owner_id, self.study_id)
self.example_trials = test_util.generate_trials(
[1, 2], owner_id=self.owner_id, study_id=self.study_id
Expand Down
17 changes: 15 additions & 2 deletions vizier/_src/service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,17 @@
# Redeclared so users do not have to also import client_abc and clients.py.
ResourceNotFoundError = client_abc.ResourceNotFoundError

# Redeclared since clients.py is the user-facing client API.
environment_variables = vizier_client.environment_variables

# TODO: Consider if user should set a one-line flag explicitly to
# denote local NO_ENDPOINT server will be used.
@attr.define
class _EnvironmentVariables:
server_endpoint: str = attr.field(
default=constants.NO_ENDPOINT, validator=attr.validators.instance_of(str)
)


environment_variables = _EnvironmentVariables()


@attr.define
Expand Down Expand Up @@ -203,6 +212,9 @@ def materialize_state(self) -> vz.StudyState:
@classmethod
def from_resource_name(cls: Type['Study'], name: str) -> 'Study':
client = vizier_client.VizierClient(
vizier_client.create_vizier_servicer_or_stub(
environment_variables.server_endpoint
),
name,
constants.UNUSED_CLIENT_ID,
)
Expand Down Expand Up @@ -252,6 +264,7 @@ def from_study_config(
"""
return Study(
vizier_client.create_or_load_study(
environment_variables.server_endpoint,
owner_id=owner,
client_id=constants.UNUSED_CLIENT_ID,
study_id=study_id,
Expand Down
11 changes: 2 additions & 9 deletions vizier/_src/service/clients_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@

from absl.testing import absltest

# Affects local Vizier servicer tests only.
clients.environment_variables.servicer_use_sql_ram()


class VizierClientTest(client_abc_testing.TestCase):
_owner: str
Expand Down Expand Up @@ -70,9 +67,7 @@ class VizierClientTestOnDefaultServer(VizierClientTest):
def setUpClass(cls):
logging.info('Test setup started.')
super().setUpClass()
cls._server = vizier_server.DefaultVizierServer(
database_url=constants.SQL_MEMORY_URL
)
cls._server = vizier_server.DefaultVizierServer()
clients.environment_variables.server_endpoint = cls._server.endpoint
logging.info('Test setup finished.')

Expand All @@ -89,9 +84,7 @@ class VizierClientTestOnDistributedPythiaServer(VizierClientTest):
def setUpClass(cls):
logging.info('Test setup started.')
super().setUpClass()
cls._server = vizier_server.DistributedPythiaVizierServer(
database_url=constants.SQL_MEMORY_URL
)
cls._server = vizier_server.DistributedPythiaVizierServer()
clients.environment_variables.server_endpoint = cls._server.endpoint
logging.info('Test setup finished.')

Expand Down
6 changes: 0 additions & 6 deletions vizier/_src/service/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

"""Hard coded constants used in /service/ folder."""
import os

# The metadata namespace under which the Pythia endpoint is stored.
PYTHIA_ENDPOINT_NAMESPACE = 'service'
Expand All @@ -34,8 +33,3 @@

# Will use RAM for SQL memory.
SQL_MEMORY_URL = 'sqlite:///:memory:'

# Will use local file path for HDD storage.
SERVICE_DIR = os.path.dirname(os.path.realpath(__file__))
VIZIER_DB_PATH = os.path.join(SERVICE_DIR, 'vizier.db')
SQL_LOCAL_URL = f'sqlite:///{VIZIER_DB_PATH}'
7 changes: 2 additions & 5 deletions vizier/_src/service/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import time
from absl import logging

from vizier._src.service import constants
from vizier._src.service import vizier_client
from vizier._src.service import vizier_server
from vizier.benchmarks import experimenters
Expand All @@ -36,10 +35,7 @@ class PerformanceTest(parameterized.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.server = vizier_server.DefaultVizierServer(
database_url=constants.SQL_MEMORY_URL
)
vizier_client.environment_variables.server_endpoint = cls.server.endpoint
cls.server = vizier_server.DefaultVizierServer()

@parameterized.parameters(
(1, 10, 2),
Expand All @@ -60,6 +56,7 @@ def fn(client_id: int):
study_config.algorithm = pyvizier.Algorithm.NSGA2

client = vizier_client.create_or_load_study(
server_endpoint=self.server.endpoint,
owner_id='my_username',
study_id=self.id(), # Use the testcase name.
study_config=study_config,
Expand Down
6 changes: 2 additions & 4 deletions vizier/_src/service/service_policy_supporter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from vizier._src.service import constants
"""Tests for vizier.service.service_policy_supporter."""
from vizier._src.service import resources
from vizier._src.service import service_policy_supporter
from vizier._src.service import study_pb2
Expand All @@ -33,9 +33,7 @@ def setUp(self):
self.study_name = resources.StudyResource(
owner_id=self.owner_id, study_id=self.study_id
).name
self.vs = vizier_service.VizierServicer(
database_url=constants.SQL_MEMORY_URL
)
self.vs = vizier_service.VizierServicer()
self.example_study = test_util.generate_study(self.owner_id, self.study_id)
self.vs.datastore.create_study(self.example_study)

Expand Down
30 changes: 1 addition & 29 deletions vizier/_src/service/sql_datastore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
from __future__ import annotations

"""Tests for sql_datastore."""
import os
import sqlalchemy as sqla

from vizier._src.service import constants
from vizier._src.service import datastore_test_lib
from vizier._src.service import sql_datastore
from vizier._src.service.testing import util as test_util
Expand Down Expand Up @@ -46,7 +44,7 @@ def setUp(self):
)
)

engine = sqla.create_engine(constants.SQL_MEMORY_URL, echo=True)
engine = sqla.create_engine('sqlite:///:memory:', echo=True)
self.datastore = sql_datastore.SQLDataStore(engine)
super().setUp()

Expand Down Expand Up @@ -78,31 +76,5 @@ def test_update_metadata(self):
)


class SQLDataStoreAdditionalTest(absltest.TestCase):
"""For additional tests outside of regular database functionality."""

def setUp(self):
super().setUp()
self.owner_id = 'my_username'
self.study_id = '123123123'
self.example_study = test_util.generate_study(self.owner_id, self.study_id)

@absltest.skip("Github workflow tests don't allow using directories.")
def test_local_hdd_persistence(self):
db_path = os.path.join(absltest.get_default_test_tmpdir(), 'local.db')
sql_url = f'sqlite:///{db_path}'

engine = sqla.create_engine(sql_url, echo=True)
datastore = sql_datastore.SQLDataStore(engine)
datastore.create_study(self.example_study)
del datastore

engine2 = sqla.create_engine(sql_url, echo=True)
datastore2 = sql_datastore.SQLDataStore(engine2)
study = datastore2.load_study(self.example_study.name)

self.assertEqual(self.example_study, study)


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 70afaad

Please sign in to comment.