Skip to content

Commit

Permalink
changed the way we handle question lable
Browse files Browse the repository at this point in the history
  • Loading branch information
Meetatgoogle committed Nov 26, 2024
1 parent 536946c commit e00568b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
7 changes: 5 additions & 2 deletions py/sight/demo/fn_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,23 @@ def warn(*args, **kwargs):

FLAGS = flags.FLAGS

def get_question_label():
return 'Q_label1'

# Define the black box function to optimize.
def black_box_function(args):
return sum(xi**2 for xi in args)


def driver(question_label,sight: Sight) -> None:
def driver(sight: Sight) -> None:
"""Executes the logic of searching for a value.
Args:
sight: The Sight logger object used to drive decisions.
"""

for _ in range(1):
next_point = decision.decision_point(question_label, sight)
next_point = decision.decision_point(get_question_label(), sight)
print('next_point : ', next_point)
reward = black_box_function(list(next_point.values()))
print('reward : ', reward)
Expand Down Expand Up @@ -96,6 +98,7 @@ def main(argv: Sequence[str]) -> None:
driver_fn=driver,
action_attrs=action_attrs,
sight=sight,
question_label=get_question_label,
)


Expand Down
13 changes: 7 additions & 6 deletions py/sight/widgets/decision/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def run(
outcome_attrs: Dict[str,
sight_pb2.DecisionConfigurationStart.AttrProps] = {},
description: str = '',
question_label: Callable[[], str]= None
):
"""Driver for running applications that use the Decision API.
Expand Down Expand Up @@ -880,14 +881,14 @@ def run(
# if(FLAGS.optimizer_type == "worklist_scheduler"):
# if (FLAGS.deployment_mode == 'worker_mode'):
# import os
optimizer_configs = utils.load_yaml_config('/x-sight/fvs_sight/optimizer_config.yaml')

for key in optimizer_configs.keys():
# optimizer_configs = utils.load_yaml_config('/x-sight/fvs_sight/optimizer_config.yaml')
# for key in optimizer_configs.keys():
while (True):
# #? new rpc just to check move forward or not?
req = service_pb2.WorkerAliveRequest(
client_id=client_id,
question_label=key,
question_label=question_label(),
worker_id=f'client_{client_id}_worker_{worker_location}'
)
response = service.call(
Expand All @@ -914,9 +915,9 @@ def run(
if env:
driver_fn(env, sight)
else:
driver_fn(key, sight)
driver_fn(sight)

finalize_episode(key, sight)
finalize_episode(question_label, sight)
sight.exit_block('Decision Sample', sight_pb2.Object())
else:
raise ValueError("invalid response from server")
Expand Down Expand Up @@ -1313,7 +1314,7 @@ def finalize_episode(question_label, sight): # , optimizer_obj
req = service_pb2.FinalizeEpisodeRequest(
client_id=client_id,
worker_id=f'client_{client_id}_worker_{worker_location}',
question_label=question_label
question_label=question_label()
)

if _OPTIMIZER_TYPE.value in [
Expand Down

0 comments on commit e00568b

Please sign in to comment.