Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add persistent state directory to python agents #645

Merged
merged 8 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server =
new PythonGrpcServer(
agentContext.getCodeDirectory(), configuration, agentId(), agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server =
new PythonGrpcServer(
agentContext.getCodeDirectory(), configuration, agentId(), agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public void init(Map<String, Object> configuration) throws Exception {

@Override
public void start() throws Exception {
server = new PythonGrpcServer(agentContext.getCodeDirectory(), configuration);
server =
new PythonGrpcServer(
agentContext.getCodeDirectory(), configuration, agentId(), agentContext);
channel = server.start();
super.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
*/
package ai.langstream.agents.grpc;

import ai.langstream.api.runner.code.AgentContext;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.Empty;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.net.ServerSocket;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -31,11 +33,19 @@ public class PythonGrpcServer {

private final Path codeDirectory;
private final Map<String, Object> configuration;
private final String agentId;
private final AgentContext agentContext;
private Process pythonProcess;

public PythonGrpcServer(Path codeDirectory, Map<String, Object> configuration) {
public PythonGrpcServer(
Path codeDirectory,
Map<String, Object> configuration,
String agentId,
AgentContext agentContext) {
this.codeDirectory = codeDirectory;
this.configuration = configuration;
this.agentId = agentId;
this.agentContext = agentContext;
}

public ManagedChannel start() throws Exception {
Expand All @@ -57,6 +67,8 @@ public ManagedChannel start() throws Exception {
pythonCodeDirectory.toAbsolutePath(),
pythonCodeDirectory.resolve("lib").toAbsolutePath());

AgentContextConfiguration agentContextConfiguration = computeAgentContextConfiguration();

// copy input/output to standard input/output of the java process
// this allows to use "kubectl logs" easily
ProcessBuilder processBuilder =
Expand All @@ -65,7 +77,8 @@ public ManagedChannel start() throws Exception {
"-m",
"langstream_grpc",
"[::]:%s".formatted(port),
MAPPER.writeValueAsString(configuration))
MAPPER.writeValueAsString(configuration),
MAPPER.writeValueAsString(agentContextConfiguration))
.inheritIO()
.redirectOutput(ProcessBuilder.Redirect.INHERIT)
.redirectError(ProcessBuilder.Redirect.INHERIT);
Expand All @@ -91,6 +104,19 @@ public ManagedChannel start() throws Exception {
return channel;
}

private AgentContextConfiguration computeAgentContextConfiguration() {
final Optional<Path> persistentStateDirectoryForAgent =
agentContext.getPersistentStateDirectoryForAgent(agentId);

final String persistentStateDirectory =
persistentStateDirectoryForAgent
.map(p -> p.toFile().getAbsolutePath())
.orElse(null);
AgentContextConfiguration agentContextConfiguration =
new AgentContextConfiguration(persistentStateDirectory);
return agentContextConfiguration;
}

public void close() throws Exception {
if (pythonProcess != null) {
pythonProcess.destroy();
Expand All @@ -102,4 +128,6 @@ public void close() throws Exception {
}
}
}

public record AgentContextConfiguration(String persistentStateDirectory) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ public static KubernetesClient getClient() {
}

private static void dumpTest(String prefix) {
dumpAllPodsLogs(prefix);
dumpAllPodsLogs(prefix + ".logs");
dumpEvents(prefix);
dumpAllResources(prefix);
dumpAllResources(prefix + ".resource");
dumpProcessOutput(prefix, "kubectl-nodes", "kubectl describe nodes".split(" "));
}

Expand Down Expand Up @@ -558,7 +558,8 @@ private void cleanupEnv() {

@AfterEach
public void cleanupAfterEach() {
cleanupAllEndToEndTestsNamespaces();
// do not cleanup langstream tenant here otherwise we won't get the logs in case of test
// failed
cleanupEnv();
}

Expand Down Expand Up @@ -688,6 +689,8 @@ private static void installLangStream(boolean authentication) {
agentResources:
cpuPerUnit: %s
memPerUnit: %s
storageClassesMapping:
default: standard
client:
image:
repository: %s/langstream-cli
Expand Down Expand Up @@ -941,7 +944,7 @@ protected static void dumpResource(String filePrefix, HasMetadata resource) {
final File outputFile =
new File(
TEST_LOGS_DIR,
"%s-%s-%s.txt"
"%s.%s.%s.txt"
.formatted(
filePrefix,
resource.getKind(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pipeline:
- name: "Process using Python"
resources:
size: 2
disk:
enabled: true
size: 50M
id: "test-python-processor"
type: "python-processor"
input: ls-test-topic0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,35 @@
# limitations under the License.
#

from langstream import SimpleRecord, Processor
import logging
from langstream import SimpleRecord, Processor, AgentContext
import logging, os


class Exclamation(Processor):
def init(self, config):
print("init", config)
def init(self, config, context: AgentContext):
print("init", config, context)
self.secret_value = config["secret_value"]
self.context = context

def process(self, record):
logging.info("Processing record" + str(record))
directory = self.context.get_persistent_state_directory()
counter_file = os.path.join(directory, "counter.txt")
counter = 0
if os.path.exists(counter_file):
with open(counter_file, "r") as f:
counter = int(f.read())
counter += 1
else:
counter = 1
with open(counter_file, 'w') as file:
file.write(str(counter))


if self.secret_value == "super secret value - changed":
assert counter == 2
else:
assert counter == 1
return [
SimpleRecord(
record.value() + "!!" + self.secret_value, headers=record.headers()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
# limitations under the License.
#

from .api import (
Agent,
Record,
RecordType,
Sink,
Source,
Processor,
)
from .api import Agent, Record, RecordType, Sink, Source, Processor, AgentContext
from .util import SimpleRecord, AvroValue

__all__ = [
Expand All @@ -34,4 +27,5 @@
"Processor",
"SimpleRecord",
"AvroValue",
"AgentContext",
]
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"Source",
"Sink",
"Processor",
"AgentContext",
]


Expand Down Expand Up @@ -61,10 +62,23 @@ def headers(self) -> List[Tuple[str, Any]]:
RecordType = Union[Record, dict, list, tuple]


class AgentContext(ABC):
"""The Agent context interface"""

def __init__(self):
"""Initialize the agent context."""
pass

@abstractmethod
def get_persistent_state_directory(self):
"""Return a path pointing to the stateful agent directory. Return None if not configured in the agent."""
pass


class Agent(ABC):
"""The Agent interface"""

def init(self, config: Dict[str, Any]):
def init(self, config: Dict[str, Any], context: AgentContext):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a breaking change ?

maybe we should have same approach as in Java and have a "setAgentContext" function

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no it is not, I added a test for the old syntax

"""Initialize the agent from the given configuration."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
datefmt="%H:%M:%S",
)

if len(sys.argv) != 3:
print("Missing gRPC target and python class name")
print("usage: python -m langstream_grpc <target> <config>")
if len(sys.argv) != 4:
print("Missing gRPC target or config or agent context")
print(
"usage: python -m langstream_grpc <target> <config> <agent_context_config>"
)
sys.exit(1)

server = AgentServer(sys.argv[1], sys.argv[2])
server = AgentServer(sys.argv[1], sys.argv[2], sys.argv[3])
server.start()
server.grpc_server.wait_for_termination()
server.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"Source",
"Sink",
"Processor",
"AgentContext",
]


Expand Down Expand Up @@ -61,10 +62,22 @@ def headers(self) -> List[Tuple[str, Any]]:
RecordType = Union[Record, dict, list, tuple]


class AgentContext(ABC):
"""The Agent context interface"""

def __init__(self):
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize the agent context."""
pass

def get_persistent_state_directory(self):
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
"""Return a path pointing to the stateful agent directory. Return None if not configured in the agent."""
return None


class Agent(ABC):
"""The Agent interface"""

def init(self, config: Dict[str, Any]):
def init(self, config: Dict[str, Any], context: AgentContext):
"""Initialize the agent from the given configuration."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import fastavro
import grpc
import inspect

from langstream_grpc.proto import agent_pb2_grpc
from langstream_grpc.proto.agent_pb2 import (
Expand All @@ -43,13 +44,7 @@
SinkResponse,
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceServicer
from .api import (
Source,
Sink,
Processor,
Record,
Agent,
)
from .api import Source, Sink, Processor, Record, Agent, AgentContext
from .util import SimpleRecord, AvroValue


Expand Down Expand Up @@ -324,27 +319,44 @@ def to_grpc_value(self, value) -> Tuple[Optional[Schema], Optional[Value]]:
def call_method_if_exists(klass, method, *args, **kwargs):
method = getattr(klass, method, None)
if callable(method):
return method(*args, **kwargs)
defined_positional_parameters_count = len(inspect.signature(method).parameters)
if defined_positional_parameters_count >= len(args):
return method(*args, **kwargs)
else:
return method(*args[:defined_positional_parameters_count], **kwargs)
return None


def init_agent(configuration) -> Agent:
def init_agent(configuration, context) -> Agent:
full_class_name = configuration["className"]
class_name = full_class_name.split(".")[-1]
module_name = full_class_name[: -len(class_name) - 1]
module = importlib.import_module(module_name)
agent = getattr(module, class_name)()
call_method_if_exists(agent, "init", configuration)
context_impl = AgentContextImpl(configuration, context)
call_method_if_exists(agent, "init", configuration, context_impl)
return agent


class AgentContextImpl(AgentContext):
nicoloboschi marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, configuration: dict, context: dict):
self.configuration = configuration
self.context = context

def get_persistent_state_directory(self) -> str:
dir = self.context.get("persistentStateDirectory", "")
if not dir:
return None
return dir


class AgentServer(object):
def __init__(self, target: str, config: str):
def __init__(self, target: str, config: str, context: str):
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
self.target = target
self.grpc_server = grpc.server(self.thread_pool)
self.port = self.grpc_server.add_insecure_port(target)
self.agent = init_agent(json.loads(config))
self.agent = init_agent(json.loads(config), json.loads(context))

def start(self):
call_method_if_exists(self.agent, "start")
Expand Down
Loading
Loading