Skip to content

Commit

Permalink
Fix gateway with Pulsar (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Oct 13, 2023
1 parent 1b5a8c1 commit 7b9c244
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.code.SimpleRecord;
import ai.langstream.api.runner.topics.OffsetPerPartition;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.runner.topics.TopicOffsetPosition;
Expand All @@ -44,7 +43,6 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
Expand Down Expand Up @@ -358,7 +356,7 @@ protected void sendEvent(EventRecord.Types type, AuthenticatedGatewayRequestCont
topicConnectionsRuntime.createProducer(
"langstream-events",
streamingCluster,
Map.of("topic", gateway.getEventsTopic())); ) {
Map.of("topic", gateway.getEventsTopic()))) {
producer.start();

final EventSources.GatewaySource source =
Expand Down Expand Up @@ -491,12 +489,12 @@ private static Map<String, String> computeMessageHeaders(Record record) {
return messageHeaders;
}

private static String computeOffset(TopicReadResult readResult) throws JsonProcessingException {
final OffsetPerPartition offsetPerPartition = readResult.partitionsOffsets();
if (offsetPerPartition == null) {
private static String computeOffset(TopicReadResult readResult) {
final byte[] offset = readResult.offset();
if (offset == null) {
return null;
}
return Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(offsetPerPartition));
return Base64.getEncoder().encodeToString(offset);
}

protected static List<Function<Record, Boolean>> createMessageFilters(
Expand Down Expand Up @@ -551,15 +549,15 @@ protected void setupReader(
.getTopicConnectionsRuntime(streamingCluster)
.asTopicConnectionsRuntime();

topicConnectionsRuntime.init(streamingCluster);

final String positionParameter = options.getOrDefault("position", "latest");
TopicOffsetPosition position =
switch (positionParameter) {
case "latest" -> TopicOffsetPosition.LATEST;
case "earliest" -> TopicOffsetPosition.EARLIEST;
default -> TopicOffsetPosition.absolute(
new String(
Base64.getDecoder().decode(positionParameter),
StandardCharsets.UTF_8));
Base64.getDecoder().decode(positionParameter));
};
TopicReader reader =
topicConnectionsRuntime.createReader(
Expand Down Expand Up @@ -589,6 +587,8 @@ protected void setupProducer(
.getTopicConnectionsRuntime(streamingCluster)
.asTopicConnectionsRuntime();

topicConnectionsRuntime.init(streamingCluster);

final TopicProducer producer =
topicConnectionsRuntime.createProducer(
null, streamingCluster, Map.of("topic", topic));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package ai.langstream.api.runner.topics;

public record TopicOffsetPosition(Position position, String offset) {
public record TopicOffsetPosition(Position position, byte[] offset) {

public enum Position {
Latest,
Expand All @@ -27,7 +27,7 @@ public enum Position {
public static final TopicOffsetPosition EARLIEST =
new TopicOffsetPosition(Position.Earliest, null);

public static TopicOffsetPosition absolute(String offset) {
public static TopicOffsetPosition absolute(byte[] offset) {
return new TopicOffsetPosition(Position.Absolute, offset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ public interface TopicReadResult {

List<Record> records();

OffsetPerPartition partitionsOffsets();
byte[] offset();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@
import ai.langstream.api.runner.topics.TopicReadResult;
import ai.langstream.api.runner.topics.TopicReader;
import ai.langstream.api.util.ClassloaderUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;

@Slf4j
Expand All @@ -56,20 +56,18 @@ public KafkaReaderWrapper(
}

@Override
public void start() {
public void start() throws IOException {
try (var context =
ClassloaderUtils.withContextClassloader(this.getClass().getClassLoader())) {
consumer = new KafkaConsumer(configuration);
consumer = new KafkaConsumer<>(configuration);
}
final List<TopicPartition> partitions =
((List<PartitionInfo>) consumer.partitionsFor(topicName))
.stream()
.map(
partitionInfo ->
new TopicPartition(
partitionInfo.topic(),
partitionInfo.partition()))
.collect(Collectors.toList());
consumer.partitionsFor(topicName).stream()
.map(
partitionInfo ->
new TopicPartition(
partitionInfo.topic(), partitionInfo.partition()))
.collect(Collectors.toList());
consumer.assign(partitions);
if (initialPosition.position() == TopicOffsetPosition.Position.Latest) {
consumer.seekToEnd(partitions);
Expand Down Expand Up @@ -102,8 +100,7 @@ public void start() {
}
}

@SneakyThrows
private OffsetPerPartition parseOffset() {
private OffsetPerPartition parseOffset() throws IOException {
return mapper.readValue(initialPosition.offset(), OffsetPerPartition.class);
}

Expand All @@ -115,7 +112,7 @@ public void close() {
}

@Override
public TopicReadResult read() {
public TopicReadResult read() throws JsonProcessingException {
ConsumerRecords<?, ?> poll = consumer.poll(Duration.ofSeconds(5));
List<Record> records = new ArrayList<>(poll.count());
for (ConsumerRecord<?, ?> record : poll) {
Expand All @@ -133,15 +130,16 @@ public TopicReadResult read() {
partitions.put(key.partition() + "", topicPartitionLongEntry.getValue() + "");
}
final OffsetPerPartition offsetPerPartition = new OffsetPerPartition(partitions);
byte[] offset = mapper.writeValueAsBytes(offsetPerPartition);
return new TopicReadResult() {
@Override
public List<Record> records() {
return records;
}

@Override
public OffsetPerPartition partitionsOffsets() {
return offsetPerPartition;
public byte[] offset() {
return offset;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.OffsetPerPartition;
import ai.langstream.api.runner.topics.TopicAdmin;
import ai.langstream.api.runner.topics.TopicConnectionsRuntime;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeProvider;
Expand All @@ -36,6 +35,7 @@
import ai.langstream.api.runtime.Topic;
import ai.langstream.pulsar.PulsarClientUtils;
import ai.langstream.pulsar.PulsarTopic;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -51,8 +51,10 @@
import org.apache.pulsar.client.admin.PulsarAdminException;
import org.apache.pulsar.client.api.Consumer;
import org.apache.pulsar.client.api.Message;
import org.apache.pulsar.client.api.MessageId;
import org.apache.pulsar.client.api.Producer;
import org.apache.pulsar.client.api.PulsarClient;
import org.apache.pulsar.client.api.Reader;
import org.apache.pulsar.client.api.Schema;
import org.apache.pulsar.client.api.SubscriptionInitialPosition;
import org.apache.pulsar.client.api.SubscriptionType;
Expand Down Expand Up @@ -100,7 +102,6 @@ public TopicReader createReader(
Map<String, Object> configuration,
TopicOffsetPosition initialPosition) {
Map<String, Object> copy = new HashMap<>(configuration);
final TopicConsumer consumer = createConsumer(null, streamingCluster, configuration);
switch (initialPosition.position()) {
case Earliest -> copy.put(
"subscriptionInitialPosition", SubscriptionInitialPosition.Earliest);
Expand All @@ -109,33 +110,7 @@ public TopicReader createReader(
default -> throw new IllegalArgumentException(
"Unsupported initial position: " + initialPosition.position());
}
return new TopicReader() {
@Override
public void start() throws Exception {
consumer.start();
}

@Override
public void close() throws Exception {
consumer.close();
}

@Override
public TopicReadResult read() throws Exception {
final List<Record> records = consumer.read();
return new TopicReadResult() {
@Override
public List<Record> records() {
return records;
}

@Override
public OffsetPerPartition partitionsOffsets() {
return null;
}
};
}
};
return new PulsarTopicReader(copy, initialPosition);
}

@Override
Expand All @@ -153,7 +128,7 @@ public TopicProducer createProducer(
StreamingCluster streamingCluster,
Map<String, Object> configuration) {
Map<String, Object> copy = new HashMap<>(configuration);
return new PulsarTopicProducer(copy);
return new PulsarTopicProducer<>(copy);
}

@Override
Expand Down Expand Up @@ -382,6 +357,83 @@ public String valueAsString() {
}
}

private class PulsarTopicReader implements TopicReader {
private final Map<String, Object> configuration;
private final MessageId startMessageId;

private Reader<GenericRecord> reader;

private PulsarTopicReader(
Map<String, Object> configuration, TopicOffsetPosition initialPosition) {
this.configuration = configuration;
this.startMessageId =
switch (initialPosition.position()) {
case Earliest -> MessageId.earliest;
case Latest -> MessageId.latest;
case Absolute -> {
try {
yield MessageId.fromByteArray(initialPosition.offset());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
};
}

@Override
public void start() throws Exception {
String topic = (String) configuration.remove("topic");
reader =
client.newReader(Schema.AUTO_CONSUME())
.topic(topic)
.startMessageId(this.startMessageId)
.loadConf(configuration)
.create();
}

@Override
public void close() throws Exception {
if (reader != null) {
reader.close();
}
}

@Override
public TopicReadResult read() throws Exception {
Message<GenericRecord> receive = reader.readNext(1, TimeUnit.SECONDS);
List<Record> records;
byte[] offset;
if (receive != null) {
Object key = receive.getKey();
Object value = receive.getValue().getNativeObject();
if (value instanceof KeyValue<?, ?> kv) {
key = kv.getKey();
value = kv.getValue();
}

final Object finalKey = key;
final Object finalValue = value;
log.info("Received message: {}", receive);
records = List.of(new PulsarConsumerRecord(finalKey, finalValue, receive));
offset = receive.getMessageId().toByteArray();
} else {
records = List.of();
offset = null;
}
return new TopicReadResult() {
@Override
public List<Record> records() {
return records;
}

@Override
public byte[] offset() {
return offset;
}
};
}
}

private class PulsarTopicConsumer implements TopicConsumer {

private final Map<String, Object> configuration;
Expand Down Expand Up @@ -469,7 +521,7 @@ public void start() {
String topic = (String) configuration.remove("topic");
schema = (Schema<K>) configuration.remove("schema");
if (schema == null) {
schema = (Schema) Schema.BYTES;
schema = (Schema) Schema.STRING;
}
producer = client.newProducer(schema).topic(topic).loadConf(configuration).create();
}
Expand Down
Loading

0 comments on commit 7b9c244

Please sign in to comment.