Skip to content

Commit

Permalink
Allow to return dict instead of Records in Python agents
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 21, 2023
1 parent 118260d commit af8243b
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#

from langstream import SimpleRecord, SingleRecordProcessor
from langstream import SingleRecordProcessor


class Exclamation(SingleRecordProcessor):
Expand All @@ -23,4 +23,4 @@ def init(self, config):
self.secret_value = config["secret_value"]

def process_record(self, record):
return [SimpleRecord(record.value() + "!!" + self.secret_value)]
return [(record.value() + "!!" + self.secret_value,)]
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def read(self) -> List[RecordType]:
"""The Source agent generates records and returns them as list of records.
:returns: the list of records. The records must either respect the Record
API contract (have methods value(), key() and so on) or be tuples/list.
API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
If the records are dicts, their keys, if present, must be among "value", "key",
"headers", "origin" and "timestamp".
Eg:
* if you return [{"value": "foo"}] a record Record(value="foo") will be built.
If the records are tuples/list, the framework will automatically construct
Record objects from them with the values in the following order : value, key,
headers, origin, timestamp.
Expand Down Expand Up @@ -138,7 +143,13 @@ def process(
exception.
Eg: [(input_record, RuntimeError("Could not process"))]
When the processing is successful, the output records must either respect the
Record API contract (have methods value(), key() and so on) or be tuples/list.
Record API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
If the records are dicts, their keys, if present, must be among "value", "key",
"headers", "origin" and "timestamp".
Eg:
* if you return [(input_record, [{"value": "foo"}])] a record
Record(value="foo") will be built.
If the output records are tuples/list, the framework will automatically
construct Record objects from them with the values in the following order :
value, key, headers, origin, timestamp.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def __init__(
self.record_id = record_id


def wrap_in_record(record):
    """Convert a raw agent output into a record object when needed.

    A tuple or list is expanded as positional arguments to ``SimpleRecord``;
    a dict is expanded as keyword arguments to ``SimpleRecord``. Any other
    value is returned unchanged.

    :param record: the value produced by the agent.
    :returns: a ``SimpleRecord`` built from the tuple/list/dict, or ``record``
        itself for every other type.
    """
    # A single isinstance call with a tuple of types replaces the chained checks.
    if isinstance(record, (tuple, list)):
        return SimpleRecord(*record)
    if isinstance(record, dict):
        return SimpleRecord(**record)
    return record


def handle_requests(
agent: Source,
requests: Iterable[SourceRequest],
Expand Down Expand Up @@ -112,6 +120,7 @@ def read(self, requests: Iterable[SourceRequest], context):
raise op_result[0]
records = self.agent.read()
if len(records) > 0:
records = [wrap_in_record(record) for record in records]
grpc_records = []
for record in records:
schemas, grpc_record = self.to_grpc_record(record)
Expand Down Expand Up @@ -140,7 +149,9 @@ def process(self, requests: Iterable[ProcessorRequest], context):
grpc_result.error = str(result)
else:
for record in result:
schemas, grpc_record = self.to_grpc_record(record)
schemas, grpc_record = self.to_grpc_record(
wrap_in_record(record)
)
for schema in schemas:
yield ProcessorResponse(schema=schema)
grpc_result.records.append(grpc_record)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,33 @@ def test_info(stub):


class MyProcessor(SingleRecordProcessor):
    """Test processor that echoes each input record in a different output shape.

    The shape of the returned record depends on the type of the input value:
    a ``str`` value returns the record object itself, a ``float`` value returns
    a dict, and any other value returns a tuple.
    """

    def __init__(self):
        self.i = 0

    def agent_info(self) -> Dict[str, Any]:
        """Return the static info dictionary exposed by this agent."""
        return {"test-info-key": "test-info-value"}

    def process_record(self, record: Record) -> List[RecordType]:
        """Echo ``record`` back as a Record, a dict or a tuple.

        :param record: the input record.
        :returns: a one-element list holding the echoed record.
        :raises Exception: when the record origin is ``"failing-record"``.
        """
        if record.origin() == "failing-record":
            raise Exception("failure")

        value = record.value()
        if isinstance(value, str):
            # Record objects are returned as-is.
            return [record]
        if isinstance(value, float):
            # Dict form: keys name the record fields.
            return [
                {
                    "value": value,
                    "key": record.key(),
                    "headers": record.headers(),
                    "origin": record.origin(),
                    "timestamp": record.timestamp(),
                }
            ]
        # Tuple form: positional order is value, key, headers, origin, timestamp.
        return [
            (
                value,
                record.key(),
                record.headers(),
                record.origin(),
                record.timestamp(),
            )
        ]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
import queue
import time
from io import BytesIO
from typing import List

Expand All @@ -31,7 +32,7 @@
)
from langstream_grpc.proto.agent_pb2_grpc import AgentServiceStub
from langstream_runtime.api import Record, RecordType, Source
from langstream_runtime.util import SimpleRecord, AvroValue
from langstream_runtime.util import AvroValue, SimpleRecord


@pytest.fixture(autouse=True)
Expand All @@ -52,8 +53,19 @@ def server_and_stub():
def test_read(server_and_stub):
server, stub = server_and_stub

responses: list[SourceResponse]
responses = list(stub.read(iter([])))
stop = False

def requests():
while not stop:
time.sleep(0.1)
yield from ()

responses: list[SourceResponse] = []
i = 0
for response in stub.read(iter(requests())):
responses.append(response)
i += 1
stop = i == 4

response_schema = responses[0]
assert len(response_schema.records) == 0
Expand All @@ -77,14 +89,26 @@ def test_read(server_and_stub):
finally:
fp.close()

response_record = responses[2]
assert len(response_schema.records) == 0
record = response_record.records[0]
assert record.record_id == 2
assert record.value.long_value == 42

response_record = responses[3]
assert len(response_schema.records) == 0
record = response_record.records[0]
assert record.record_id == 3
assert record.value.long_value == 43


def test_commit(server_and_stub):
server, stub = server_and_stub
to_commit = queue.Queue()

def send_commit():
committed = 0
while committed < 2:
while committed < 3:
try:
commit_id = to_commit.get(True, 1)
yield SourceRequest(committed_records=[commit_id])
Expand All @@ -98,8 +122,9 @@ def send_commit():
for record in response.records:
to_commit.put(record.record_id)

assert len(server.agent.committed) == 1
assert len(server.agent.committed) == 2
assert server.agent.committed[0] == server.agent.sent[0]
assert server.agent.committed[1].value() == server.agent.sent[1]["value"]


def test_permanent_failure(server_and_stub):
Expand Down Expand Up @@ -141,7 +166,8 @@ def __init__(self):
value={"field": "test"},
)
),
SimpleRecord(value=42),
{"value": 42},
(43,),
]
self.sent = []
self.committed = []
Expand All @@ -156,7 +182,7 @@ def read(self) -> List[RecordType]:

def commit(self, records: List[Record]):
    """Record the committed records, failing on the poisoned value.

    :param records: the records being committed; each must expose ``value()``.
    :raises Exception: with message ``"test error"`` when any record has the
        value 43. In that case nothing from this call is added to
        ``self.committed``.
    """
    # Validate the whole batch before extending, so a failure commits nothing.
    for record in records:
        if record.value() == 43:
            raise Exception("test error")
    self.committed.extend(records)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
"Sink",
"Processor",
"CommitCallback",
"TopicConsumer",
"TopicProducer",
]


Expand Down Expand Up @@ -146,7 +144,12 @@ def read(self) -> List[RecordType]:
"""The Source agent generates records and returns them as list of records.
:returns: the list of records. The records must either respect the Record
API contract (have methods value(), key() and so on) or be tuples/list.
API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
If the records are dicts, their keys, if present, must be among "value", "key",
"headers", "origin" and "timestamp".
Eg:
* if you return [{"value": "foo"}] a record Record(value="foo") will be built.
If the records are tuples/list, the framework will automatically construct
Record objects from them with the values in the following order : value, key,
headers, origin, timestamp.
Expand Down Expand Up @@ -192,7 +195,13 @@ def process(
exception.
Eg: [(input_record, RuntimeError("Could not process"))]
When the processing is successful, the output records must either respect the
Record API contract (have methods value(), key() and so on) or be tuples/list.
Record API contract (have methods value(), key() and so on) or be a dict or
tuples/list.
If the records are dicts, their keys, if present, must be among "value", "key",
"headers", "origin" and "timestamp".
Eg:
* if you return [(input_record, [{"value": "foo"}])] a record
Record(value="foo") will be built.
If the output records are tuples/list, the framework will automatically
construct Record objects from them with the values in the following order :
value, key, headers, origin, timestamp.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def wrap_in_record(records):
for i, record in enumerate(records):
if isinstance(record, tuple) or isinstance(record, list):
records[i] = SimpleRecord(*record)
if isinstance(record, dict):
records[i] = SimpleRecord(**record)
return records


Expand Down

0 comments on commit af8243b

Please sign in to comment.