Skip to content

Commit

Permalink
Adapt to train-test split
Browse files Browse the repository at this point in the history
  • Loading branch information
gmontamat committed Sep 23, 2024
1 parent d06e3e9 commit 196cb8b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions examples/sample_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# docker run -d --rm --name gentun-redis -p 6379:6379 redis
worker = RedisWorker("test", Dummy, host="localhost", port=6379)

x_train = []
y_train = []
x_train, y_train = [], []
x_test, y_test = [], []

# Start worker process
worker.run(x_train, y_train)
worker.run(x_train, y_train, x_test, y_test)
12 changes: 6 additions & 6 deletions src/gentun/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from gentun.services import RedisWorker
worker = RedisWorker("{name}", {handler}, host="{host}", port={port})
x_train, y_train = ... # get data
worker.run(x_train, y_train)
x_train, y_train, x_test, y_test = ... # get data
worker.run(x_train, y_train, x_test, y_test)
```
"""

Expand Down Expand Up @@ -116,11 +116,11 @@ def __init__(
self.results_queue = results_queue
self.timeout = timeout

def process_job(self, x_train: Any, y_train: Any, **kwargs) -> float:
def process_job(self, x_train: Any, y_train: Any, x_test: Any, y_test: Any, **kwargs) -> float:
"""Call model handler, return fitness."""
return self.handler(**kwargs).evaluate(x_train, y_train)
return self.handler(**kwargs)(x_train, y_train)

def run(self, x_train: Any, y_train: Any):
def run(self, x_train: Any, y_train: Any, x_test: Any = None, y_test: Any = None):
"""Read jobs from queue, call handler, and return fitness."""
logging.info("Worker started (Ctrl+C to stop), waiting for jobs...")
try:
Expand All @@ -130,7 +130,7 @@ def run(self, x_train: Any, y_train: Any):
data = json.loads(job_data)
if data["name"] == self.name and data["handler"] == self.handler.__name__:
logging.info("Working on job %s", data["id"])
fitness = self.process_job(x_train, y_train, **data["kwargs"])
fitness = self.process_job(x_train, y_train, x_test, y_test, **data["kwargs"])
result = {"id": data["id"], "name": self.name, "fitness": fitness}
self.client.rpush(self.results_queue, json.dumps(result))
else:
Expand Down

0 comments on commit 196cb8b

Please sign in to comment.