Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add persistent state directory to python agents #645

Merged
merged 8 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration, agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration, agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration, agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
*/
package ai.langstream.agents.grpc;

import ai.langstream.api.runner.code.AgentContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.Empty;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.net.ServerSocket;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -31,11 +33,13 @@ public class PythonGrpcServer {

private final Path codeDirectory;
private final Map<String, Object> configuration;
private final AgentContext agentContext;
private Process pythonProcess;

public PythonGrpcServer(Path codeDirectory, Map<String, Object> configuration) {
public PythonGrpcServer(Path codeDirectory, Map<String, Object> configuration, AgentContext agentContext) {
this.codeDirectory = codeDirectory;
this.configuration = configuration;
this.agentContext = agentContext;
}

public ManagedChannel start() throws Exception {
Expand All @@ -57,6 +61,8 @@ public ManagedChannel start() throws Exception {
pythonCodeDirectory.toAbsolutePath(),
pythonCodeDirectory.resolve("lib").toAbsolutePath());

AgentContextConfiguration agentContextConfiguration = computeAgentContextConfiguration();

// copy input/output to standard input/output of the java process
// this allows to use "kubectl logs" easily
ProcessBuilder processBuilder =
Expand All @@ -65,7 +71,8 @@ public ManagedChannel start() throws Exception {
"-m",
"langstream_grpc",
"[::]:%s".formatted(port),
MAPPER.writeValueAsString(configuration))
MAPPER.writeValueAsString(configuration),
MAPPER.writeValueAsString(agentContextConfiguration))
.inheritIO()
.redirectOutput(ProcessBuilder.Redirect.INHERIT)
.redirectError(ProcessBuilder.Redirect.INHERIT);
Expand All @@ -91,6 +98,15 @@ public ManagedChannel start() throws Exception {
return channel;
}

/**
 * Builds the agent context configuration from the current {@code agentContext}.
 *
 * @return a configuration holding the absolute path of the persistent state directory
 *     reported for this agent's id, or holding {@code null} when no directory is present
 */
private AgentContextConfiguration computeAgentContextConfiguration() {
    final String stateDirectoryPath =
            agentContext
                    .getPersistentStateDirectoryForAgent(agentContext.getAgentId())
                    .map(stateDirectory -> stateDirectory.toFile().getAbsolutePath())
                    .orElse(null);
    return new AgentContextConfiguration(stateDirectoryPath);
}

public void close() throws Exception {
if (pythonProcess != null) {
pythonProcess.destroy();
Expand All @@ -102,4 +118,6 @@ public void close() throws Exception {
}
}
}

public record AgentContextConfiguration(String persistentStateDirectory) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,35 @@
# limitations under the License.
#

from langstream import SimpleRecord, Processor
import logging
from langstream import SimpleRecord, Processor, AgentContext
import logging, os


class Exclamation(Processor):
def init(self, config):
def init(self, config, context: AgentContext):
print("init", config)
self.secret_value = config["secret_value"]
self.context = context

def process(self, record):
logging.info("Processing record" + str(record))
directory = self.context.get_persistent_state_directory()
counter_file = os.path.resolve(directory, "counter.txt")
counter = 0
if os.path.exists(counter_file):
with open(counter_file, "r") as f:
counter = int(f.read())
counter += 1
else:
counter = 1
with open(counter_file, 'w') as file:
file.write(str(counter))


if self.secret_value == "super secret value - changed":
assert counter == 2
else:
assert counter == 1
return [
SimpleRecord(
record.value() + "!!" + self.secret_value, headers=record.headers()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[packages]
[packages]
grpcio = "*"
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
fastavro = "*"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"Source",
"Sink",
"Processor",
"AgentContext"
]


Expand Down Expand Up @@ -60,11 +61,25 @@ def headers(self) -> List[Tuple[str, Any]]:

RecordType = Union[Record, dict, list, tuple]

class AgentContext(ABC):
"""The Agent context interface"""

def __init__(self):
"""Initialize the agent context."""
pass


@abstractmethod
def get_persistent_state_directory(self):
"""Return a path pointing to the stateful agent directory. Return None if not configured in the agent."""
pass



class Agent(ABC):
"""The Agent interface"""

def init(self, config: Dict[str, Any]):
def init(self, config: Dict[str, Any], context: AgentContext):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a breaking change?

Maybe we should take the same approach as in Java and add a "setAgentContext" function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is not; I added a test for the old syntax.

"""Initialize the agent from the given configuration."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@
datefmt="%H:%M:%S",
)

if len(sys.argv) != 3:
if len(sys.argv) <= 3:
print("Missing gRPC target and python class name")
print("usage: python -m langstream_grpc <target> <config>")
sys.exit(1)

server = AgentServer(sys.argv[1], sys.argv[2])
context_config = {}
if len(sys.argv) > 3:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The process is always launched by the same version of the Java runtime, so there is no need to do this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

context_config = sys.argv[3]
server = AgentServer(sys.argv[1], sys.argv[2], context_config)
server.start()
server.grpc_server.wait_for_termination()
server.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"Source",
"Sink",
"Processor",
"AgentContext"
]


Expand Down Expand Up @@ -61,10 +62,23 @@ def headers(self) -> List[Tuple[str, Any]]:
RecordType = Union[Record, dict, list, tuple]


class AgentContext(ABC):
"""The Agent context interface"""

def __init__(self):
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize the agent context."""
pass


def get_persistent_state_directory(self):
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
"""Return a path pointing to the stateful agent directory. Return None if not configured in the agent."""
return None


class Agent(ABC):
"""The Agent interface"""

def init(self, config: Dict[str, Any]):
def init(self, config: Dict[str, Any], context: AgentContext):
"""Initialize the agent from the given configuration."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import fastavro
import grpc
import inspect

from langstream_grpc.proto import agent_pb2_grpc
from langstream_grpc.proto.agent_pb2 import (
Expand All @@ -49,6 +50,7 @@
Processor,
Record,
Agent,
AgentContext
)
from .util import SimpleRecord, AvroValue

Expand Down Expand Up @@ -324,27 +326,43 @@ def to_grpc_value(self, value) -> Tuple[Optional[Schema], Optional[Value]]:
def call_method_if_exists(klass, method, *args, **kwargs):
    """Call ``klass.<method>`` if it exists and is callable.

    Positional arguments that the target method cannot accept are dropped from
    the end, so a method written against an older, shorter signature (for
    example ``init(self, config)``) keeps working when callers pass extra
    trailing arguments (for example ``init(config, context)``).

    :param klass: object on which the method is looked up.
    :param method: name of the method to call.
    :param args: positional arguments; truncated to what the method accepts.
    :param kwargs: keyword arguments, always forwarded unchanged.
    :return: the method's return value, or None if the method does not exist
        or is not callable.
    """
    method = getattr(klass, method, None)
    if not callable(method):
        return None

    try:
        parameters = list(inspect.signature(method).parameters.values())
    except (TypeError, ValueError):
        # The callable has no introspectable signature; forward everything.
        return method(*args, **kwargs)

    # A ``*args`` parameter accepts any number of positional arguments.
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in parameters):
        return method(*args, **kwargs)

    # Only these kinds can receive a positional argument; keyword-only and
    # ``**kwargs`` parameters must not count towards the positional arity.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    accepted_count = sum(1 for p in parameters if p.kind in positional_kinds)
    return method(*args[:accepted_count], **kwargs)


def init_agent(configuration) -> Agent:
def init_agent(configuration, context) -> Agent:
full_class_name = configuration["className"]
class_name = full_class_name.split(".")[-1]
module_name = full_class_name[: -len(class_name) - 1]
module = importlib.import_module(module_name)
agent = getattr(module, class_name)()
call_method_if_exists(agent, "init", configuration)
context_impl = AgentContextImpl(configuration, context)
call_method_if_exists(agent, "init", configuration, context_impl)
return agent

class AgentContextImpl(AgentContext):
    """AgentContext implementation backed by plain configuration dictionaries."""

    def __init__(self, configuration: dict, context: dict):
        """Store the agent configuration and the context configuration.

        :param configuration: the agent configuration dictionary.
        :param context: the context dictionary; ``persistentStateDirectory``
            is the only key read by this class.
        """
        self.configuration = configuration
        self.context = context

    def get_persistent_state_directory(self) -> Optional[str]:
        """Return the ``persistentStateDirectory`` context value.

        :return: the configured value, or None when the key is missing or its
            value is empty.
        """
        # A missing key and an empty value both mean "not configured".
        return self.context.get("persistentStateDirectory") or None


class AgentServer(object):
def __init__(self, target: str, config: str):
def __init__(self, target: str, config: str, context: str):
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
self.target = target
self.grpc_server = grpc.server(self.thread_pool)
self.port = self.grpc_server.add_insecure_port(target)
self.agent = init_agent(json.loads(config))
self.agent = init_agent(json.loads(config), json.loads(context))

def start(self):
call_method_if_exists(self.agent, "start")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,23 @@

from typing import Optional

import grpc
import grpc, json

from langstream_grpc.grpc_service import AgentServer
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub


class ServerAndStub(object):
def __init__(self, class_name):
self.class_name = class_name
def __init__(self, class_name, agent_config = {}, context = {}):
self.config = agent_config.copy()
self.config["className"] = class_name
self.context = context
self.server: Optional[AgentServer] = None
self.channel: Optional[grpc.Channel] = None
self.stub: Optional[AgentServiceStub] = None

def __enter__(self):
config = f"""{{
"className": "{self.class_name}"
}}"""
self.server = AgentServer("[::]:0", config)
self.server = AgentServer("[::]:0", json.dumps(self.config), json.dumps(self.context))
self.server.start()
self.channel = grpc.insecure_channel("localhost:%d" % self.server.port)
self.stub = AgentServiceStub(channel=self.channel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest
from google.protobuf import empty_pb2

from langstream_grpc.api import Record, RecordType, Processor
from langstream_grpc.api import Record, RecordType, Processor, AgentContext
from langstream_grpc.proto.agent_pb2 import (
ProcessorRequest,
Record as GrpcRecord,
Expand Down Expand Up @@ -212,6 +212,33 @@ def test_info():
assert info.json_info == '{"test-info-key": "test-info-value"}'



def test_init_one_parameter():
with ServerAndStub(
"langstream_grpc.tests.test_grpc_processor.ProcessorInitOneParameter",
{"my-param": "my-value"}
) as server_and_stub:
for response in server_and_stub.stub.process(
iter([ProcessorRequest(records=[GrpcRecord()])])
):
assert len(response.results) == 1
result = response.results[0].records[0]
assert result.value.string_value == "my-value"


def test_processor_use_context():
with ServerAndStub(
"langstream_grpc.tests.test_grpc_processor.ProcessorUseContext",
{"my-param": "my-value"},
{"persistentStateDirectory": "/tmp/processor"}
) as server_and_stub:
for response in server_and_stub.stub.process(
iter([ProcessorRequest(records=[GrpcRecord()])])
):
assert len(response.results) == 1
result = response.results[0].records[0]
assert result.value.string_value == "directory is /tmp/processor"

class MyProcessor(Processor):
def agent_info(self) -> Dict[str, Any]:
return {"test-info-key": "test-info-value"}
Expand Down Expand Up @@ -252,3 +279,24 @@ def __init__(self):

def process(self, record: Record) -> Future[List[RecordType]]:
return self.executor.submit(lambda r: [r], record)



class ProcessorInitOneParameter(Processor):
def init(self, agent_config):
self.myparam = agent_config["my-param"]

def process(self, record: Record) -> List[RecordType]:
return [{
"value": self.myparam
}]

class ProcessorUseContext(Processor):
def init(self, agent_config, context: AgentContext):
self.myparam = agent_config["my-param"]
self.context = context

def process(self, record: Record) -> List[RecordType]:
return [{
"value": "directory is " + str(self.context.get_persistent_state_directory())
}]
Loading