Skip to content

Commit

Permalink
feat: introduce realtime trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
tchiotludo committed May 2, 2024
1 parent 8dc3e5a commit d6f5e95
Show file tree
Hide file tree
Showing 9 changed files with 410 additions and 180 deletions.
183 changes: 128 additions & 55 deletions src/main/java/io/kestra/plugin/amqp/Consume.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.slf4j.Logger;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.scheduler.Schedulers;

import java.io.*;
import java.net.URI;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;

import static io.kestra.core.utils.Rethrow.throwRunnable;
import static io.kestra.core.utils.Rethrow.throwConsumer;

@SuperBuilder
@ToString
Expand Down Expand Up @@ -56,93 +62,160 @@ public class Consume extends AbstractAmqpConnection implements RunnableTask<Cons

private Duration maxDuration;


@Override
public Consume.Output run(RunContext runContext) throws Exception {
ConnectionFactory factory = this.connectionFactory(runContext);

File tempFile = runContext.tempFile(".ion").toFile();
AtomicInteger total = new AtomicInteger();
ZonedDateTime started = ZonedDateTime.now();

File tempFile = runContext.tempFile(".ion").toFile();

if (this.maxDuration == null && this.maxRecords == null) {
throw new Exception("maxDuration or maxRecords must be set to avoid infinite loop");
}

try (BufferedOutputStream outputFile = new BufferedOutputStream(new FileOutputStream(tempFile));
Connection connection = factory.newConnection()) {
ConnectionFactory factory = this.connectionFactory(runContext);

try (
BufferedOutputStream outputFile = new BufferedOutputStream(new FileOutputStream(tempFile));
ConsumeThread thread = new ConsumeThread(
factory,
runContext,
this,
throwConsumer(message -> {
FileSerde.write(outputFile, message);
total.getAndIncrement();
}),
() -> this.ended(total, started)
);
) {
thread.start();
thread.join();

if (thread.getException() != null) {
throw thread.getException();
}

runContext.metric(Counter.of("records", total.get(), "queue", runContext.render(this.queue)));
outputFile.flush();

return Output.builder()
.uri(runContext.storage().putFile(tempFile))
.count(total.get())
.build();
}
}

public Publisher<Message> stream(RunContext runContext) {
return Flux.<Message>create(
fluxSink -> {
try {
ConnectionFactory factory = this.connectionFactory(runContext);

try (
ConsumeThread thread = new ConsumeThread(
factory,
runContext,
this,
throwConsumer(fluxSink::next),
() -> false
);
) {
thread.start();
thread.join();
}
} catch (Throwable e) {
fluxSink.error(e);
} finally {
fluxSink.complete();
}
},
FluxSink.OverflowStrategy.BUFFER
)
.subscribeOn(Schedulers.boundedElastic());
}

@SuppressWarnings("RedundantIfStatement")
private boolean ended(AtomicInteger count, ZonedDateTime start) {
if (this.maxRecords != null && count.get() >= this.maxRecords) {
return true;
}
if (this.maxDuration != null && ZonedDateTime.now().toEpochSecond() > start.plus(this.maxDuration).toEpochSecond()) {
return true;
}

return false;
}

public static class ConsumeThread extends Thread implements AutoCloseable {
private final AtomicReference<Long> lastDeliveryTag = new AtomicReference<>();
private final AtomicReference<Exception> exception = new AtomicReference<>();
private final Supplier<Boolean> endSupplier;

private final ConnectionFactory factory;
private final RunContext runContext;
private final ConsumeBaseInterface consumeInterface;
private final Consumer<Message> consumer;

private Connection connection;
private Channel channel;

public ConsumeThread(ConnectionFactory factory, RunContext runContext, ConsumeBaseInterface consumeInterface, Consumer<Message> consumer, Supplier<Boolean> supplier) {
super("amqp-consume");
this.setDaemon(true);
this.factory = factory;
this.runContext = runContext;
this.consumeInterface = consumeInterface;
this.consumer = consumer;
this.endSupplier = supplier;
}

Channel channel = connection.createChannel();
public Exception getException() {
return this.exception.get();
}

AtomicReference<Long> lastDeliveryTag = new AtomicReference<>();
AtomicReference<Exception> threadException = new AtomicReference<>();
@Override
public void run() {
try {
connection = factory.newConnection();
channel = connection.createChannel();

Thread thread = new Thread(throwRunnable(() -> {
channel.basicConsume(
runContext.render(this.queue),
runContext.render(consumeInterface.getQueue()),
false,
this.consumerTag,
runContext.render(consumeInterface.getConsumerTag()),
(consumerTag, message) -> {
Message msg = null;
try {
msg = Message.of(message.getBody(), serdeType, message.getProperties());
consumer.accept(Message.of(message.getBody(), consumeInterface.getSerdeType(), message.getProperties()));
lastDeliveryTag.set(message.getEnvelope().getDeliveryTag());
} catch (Exception e) {
threadException.set(e);
exception.set(e);
}
FileSerde.write(outputFile, msg);
total.getAndIncrement();

lastDeliveryTag.set(message.getEnvelope().getDeliveryTag());

},
(consumerTag) -> {
},
(consumerTag1, sig) -> {
}
);
}));
thread.setDaemon(true);
thread.setName("amqp-consume");
thread.start();

while (!this.ended(total, started)) {
if (threadException.get() != null) {
channel.basicCancel(this.consumerTag);
channel.close();
thread.join();
throw threadException.get();
// keep thread running
while (exception != null && !endSupplier.get()) {
Thread.sleep(100);
}
Thread.sleep(100);
} catch (Exception e) {
exception.set(e);
}
channel.basicCancel(this.consumerTag);
}

@Override
public void close() throws Exception {
channel.basicCancel(runContext.render(consumeInterface.getConsumerTag()));

if (lastDeliveryTag.get() != null) {
channel.basicAck(lastDeliveryTag.get(), true);
}

channel.close();
thread.join();

runContext.metric(Counter.of("records", total.get(), "queue", runContext.render(this.queue)));
outputFile.flush();
}
return Output.builder()
.uri(runContext.putTempFile(tempFile))
.count(total.get())
.build();
}

@SuppressWarnings("RedundantIfStatement")
private boolean ended(AtomicInteger count, ZonedDateTime start) {
if (this.maxRecords != null && count.get() >= this.maxRecords) {
return true;
connection.close();
}
if (this.maxDuration != null && ZonedDateTime.now().toEpochSecond() > start.plus(this.maxDuration).toEpochSecond()) {
return true;
}

return false;
}

@Builder
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/io/kestra/plugin/amqp/ConsumeBaseInterface.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.kestra.plugin.amqp;

import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.plugin.amqp.models.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;

public interface ConsumeBaseInterface {
@NotNull
@PluginProperty(dynamic = true)
@Schema(
title = "The queue to pull messages from."
)
String getQueue();

@PluginProperty(dynamic = true)
@Schema(
title = "A client-generated consumer tag to establish context."
)
@NotNull
String getConsumerTag();

@NotNull
SerdeType getSerdeType();
}
21 changes: 1 addition & 20 deletions src/main/java/io/kestra/plugin/amqp/ConsumeInterface.java
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
package io.kestra.plugin.amqp;

import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.plugin.amqp.models.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;

import jakarta.validation.constraints.NotNull;
import java.time.Duration;

public interface ConsumeInterface {
@NotNull
@PluginProperty(dynamic = true)
@Schema(
title = "The queue to pull messages from."
)
String getQueue();

@Schema(
title = "A client-generated consumer tag to establish context."
)
@NotNull
String getConsumerTag();

public interface ConsumeInterface extends ConsumeBaseInterface {
@Schema(
title = "The maximum number of rows to fetch before stopping.",
description = "It's not an hard limit and is evaluated every second."
Expand All @@ -32,7 +16,4 @@ public interface ConsumeInterface {
description = "It's not an hard limit and is evaluated every second."
)
Duration getMaxDuration();

@NotNull
SerdeType getSerdeType();
}
79 changes: 79 additions & 0 deletions src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package io.kestra.plugin.amqp;

import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.conditions.ConditionContext;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.triggers.*;
import io.kestra.plugin.amqp.models.Message;
import io.kestra.plugin.amqp.models.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;

import java.time.Duration;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "React to and consume messages from an AMQP queue creating one executions for each message."
)
@Plugin(
examples = {
@Example(
code = {
"url: amqp://guest:guest@localhost:5672/my_vhost",
"maxRecords: 2",
"queue: amqpTrigger.queue"
}
)
}
)
public class RealtimeTrigger extends AbstractTrigger implements RealtimeTriggerInterface, TriggerOutput<Message>, ConsumeBaseInterface, AmqpConnectionInterface {
@Builder.Default
private final Duration interval = Duration.ofSeconds(60);

private String url;
private String host;
private String port;
private String username;
private String password;
private String virtualHost;

private String queue;

@Builder.Default
private String consumerTag = "Kestra";

private Integer maxRecords;

private Duration maxDuration;

@Builder.Default
private SerdeType serdeType = SerdeType.STRING;

@Override
public Publisher<Execution> evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception {
Consume task = Consume.builder()
.url(this.url)
.host(this.host)
.port(this.port)
.username(this.username)
.password(this.password)
.virtualHost(this.virtualHost)
.queue(this.queue)
.consumerTag(this.consumerTag)
.maxRecords(this.maxRecords)
.maxDuration(this.maxDuration)
.serdeType(this.serdeType)
.build();

return Flux.from(task.stream(conditionContext.getRunContext()))
.map((record) -> TriggerService.generateRealtimeExecution(this, context, record));
}
}
2 changes: 1 addition & 1 deletion src/main/java/io/kestra/plugin/amqp/models/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@Value
@Builder

public class Message {
public class Message implements io.kestra.core.models.tasks.Output {
String contentType;
String contentEncoding;
Map<String, Object> headers;
Expand Down
Loading

0 comments on commit d6f5e95

Please sign in to comment.