Skip to content

Commit

Permalink
feat(jdbc): handle deadlock retry
Browse files Browse the repository at this point in the history
  • Loading branch information
tchiotludo committed Jun 17, 2022
1 parent 63aab8f commit 89f6bc2
Show file tree
Hide file tree
Showing 17 changed files with 176 additions and 73 deletions.
14 changes: 9 additions & 5 deletions core/src/main/java/io/kestra/core/runners/FlowListeners.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,18 @@ private synchronized void upsert(Flow flow) {
}

private void notifyConsumers() {
this.consumers
.forEach(consumer -> consumer.accept(new ArrayList<>(this.flows)));
synchronized (this) {
this.consumers
.forEach(consumer -> consumer.accept(new ArrayList<>(this.flows)));
}
}

@Override
public void listen(Consumer<List<Flow>> consumer) {
consumers.add(consumer);
consumer.accept(new ArrayList<>(this.flows()));
public synchronized void listen(Consumer<List<Flow>> consumer) {
synchronized (this) {
consumers.add(consumer);
consumer.accept(new ArrayList<>(this.flows()));
}
}

@SneakyThrows
Expand Down
36 changes: 26 additions & 10 deletions core/src/main/java/io/kestra/core/utils/RetryUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

import jakarta.inject.Singleton;

@Singleton
Expand All @@ -31,10 +32,10 @@ public <T, E extends Throwable> Instance<T, E> of(AbstractRetry policy) {
.build();
}

public <T, E extends Throwable> Instance<T, E> of(AbstractRetry policy, Supplier<E> failureSupplier) {
public <T, E extends Throwable> Instance<T, E> of(AbstractRetry policy, Function<RetryFailed, E> failureFunction) {
return Instance.<T, E>builder()
.policy(policy)
.failureSupplier(failureSupplier)
.failureFunction(failureFunction)
.build();
}

Expand All @@ -60,13 +61,13 @@ public static class Instance<T, E extends Throwable> {
@Builder.Default
private final Logger logger = log;

private final Supplier<E> failureSupplier;
private final Function<RetryFailed, E> failureFunction;

public T run(Class<E> exception, CheckedSupplier<T> run) throws E {
return wrap(
Failsafe
.with(
this.exceptionFallback(this.failureSupplier)
this.exceptionFallback(this.failureFunction)
.handle(exception),
this.toPolicy(this.policy)
.handle(exception)
Expand All @@ -79,7 +80,7 @@ public T run(List<Class<? extends Throwable>> list, CheckedSupplier<T> run) thro
return wrap(
Failsafe
.with(
this.exceptionFallback(this.failureSupplier)
this.exceptionFallback(this.failureFunction)
.handleIf((t, throwable) -> list.stream().anyMatch(cls -> cls.isInstance(throwable))),
this.toPolicy(this.policy)
.handleIf((t, throwable) -> list.stream().anyMatch(cls -> cls.isInstance(throwable)))
Expand All @@ -88,11 +89,24 @@ public T run(List<Class<? extends Throwable>> list, CheckedSupplier<T> run) thro
);
}

public T runRetryIf(Predicate<? extends E> predicate, CheckedSupplier<T> run) {
return wrap(
Failsafe
.with(
this.exceptionFallback(this.failureFunction)
.handleIf(predicate),
this.toPolicy(this.policy)
.handleIf(predicate)
),
run
);
}

public T run(BiPredicate<T, Throwable> predicate, CheckedSupplier<T> run) throws E {
return wrap(
Failsafe
.with(
this.exceptionFallback(this.failureSupplier)
this.exceptionFallback(this.failureFunction)
.handleIf(predicate),
this.toPolicy(this.policy)
.handleIf(predicate)
Expand All @@ -105,7 +119,7 @@ public T run(Predicate<T> predicate, CheckedSupplier<T> run) throws E {
return wrap(
Failsafe
.with(
this.exceptionFallback(this.failureSupplier)
this.exceptionFallback(this.failureFunction)
.handleResultIf(predicate),
this.toPolicy(this.policy)
.handleResultIf(predicate)
Expand All @@ -123,10 +137,12 @@ private static <T, E extends Throwable> T wrap(FailsafeExecutor<T> failsafeExecu
}
}

private Fallback<T> exceptionFallback(Supplier<E> failureSupplier) throws E {
private Fallback<T> exceptionFallback(Function<RetryFailed, E> failureFunction) {
return Fallback
.ofException((ExecutionAttemptedEvent<? extends T> executionAttemptedEvent) -> {
throw failureSupplier != null ? failureSupplier.get() : new RetryFailed(executionAttemptedEvent);
RetryFailed retryFailed = new RetryFailed(executionAttemptedEvent);

throw failureFunction != null ? failureFunction.apply(retryFailed) : retryFailed;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void trigger() throws InterruptedException, TimeoutException {
assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));

// trigger is done
countDownLatch.await(5, TimeUnit.SECONDS);
countDownLatch.await(10, TimeUnit.SECONDS);
assertThat(ended.size(), is(3));

Execution triggerExecution = ended.entrySet()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.kestra.repository.mysql;

import io.kestra.core.models.DeletedInterface;
import io.kestra.core.repositories.ArrayListTotal;
import io.kestra.jdbc.AbstractJdbcRepository;
import io.micronaut.context.ApplicationContext;
Expand Down Expand Up @@ -40,7 +39,7 @@ public <R extends Record, E> ArrayListTotal<E> fetchPage(DSLContext context, Sel
.fetch()
.map(mapper);

return dslContext.transactionResult(configuration -> new ArrayListTotal<>(
return dslContextWrapper.transactionResult(configuration -> new ArrayListTotal<>(
map,
DSL.using(configuration).fetchOne("SELECT FOUND_ROWS()").into(Integer.class)
));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.kestra.repository.postgres;

import io.kestra.core.models.DeletedInterface;
import io.kestra.core.repositories.ArrayListTotal;
import io.kestra.jdbc.AbstractJdbcRepository;
import io.micronaut.context.ApplicationContext;
Expand Down Expand Up @@ -59,7 +58,7 @@ public void persist(T entity, DSLContext context, @Nullable Map<Field<Object>,
@SuppressWarnings("unchecked")
public <R extends Record, E> ArrayListTotal<E> fetchPage(DSLContext context, SelectConditionStep<R> select, Pageable pageable, RecordMapper<R, E> mapper) {
Result<Record> results = this.limit(
this.dslContext.select(DSL.asterisk(), DSL.count().over().as("total_count"))
context.select(DSL.asterisk(), DSL.count().over().as("total_count"))
.from(this
.sort(select, pageable)
.asTable("page")
Expand Down
8 changes: 4 additions & 4 deletions jdbc/src/main/java/io/kestra/jdbc/AbstractJdbcRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public abstract class AbstractJdbcRepository<T> {
protected final Class<T> cls;

@Getter
protected final DSLContext dslContext;
protected final DSLContextWrapper dslContextWrapper;

@Getter
protected final Table<Record> table;
Expand All @@ -38,7 +38,7 @@ public AbstractJdbcRepository(
) {
this.cls = cls;
this.queueService = applicationContext.getBean(QueueService.class);
this.dslContext = applicationContext.getBean(DSLContext.class);
this.dslContextWrapper = applicationContext.getBean(DSLContextWrapper.class);

JdbcConfiguration jdbcConfiguration = applicationContext.getBean(JdbcConfiguration.class);

Expand Down Expand Up @@ -69,7 +69,7 @@ public void persist(T entity) {
}

public void persist(T entity, Map<Field<Object>, Object> fields) {
dslContext.transaction(configuration ->
dslContextWrapper.transaction(configuration ->
this.persist(entity, DSL.using(configuration), fields)
);
}
Expand All @@ -87,7 +87,7 @@ public void persist(T entity, DSLContext dslContext, Map<Field<Object>, Object>
}

public void delete(T entity) {
dslContext.transaction(configuration -> {
dslContextWrapper.transaction(configuration -> {
this.delete(DSL.using(configuration), entity);
});
}
Expand Down
83 changes: 83 additions & 0 deletions jdbc/src/main/java/io/kestra/jdbc/DSLContextWrapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package io.kestra.jdbc;

import io.kestra.core.models.tasks.retrys.Exponential;
import io.kestra.core.utils.RetryUtils;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.jooq.DSLContext;
import org.jooq.TransactionalCallable;
import org.jooq.TransactionalRunnable;

import java.sql.SQLException;
import java.time.Duration;
import java.util.function.Predicate;

@Singleton
public class DSLContextWrapper {
private final DSLContext dslContext;

private final RetryUtils retryUtils;

@Inject
public DSLContextWrapper(DSLContext dslContext, RetryUtils retryUtils) {
this.dslContext = dslContext;
this.retryUtils = retryUtils;
}

private <T> RetryUtils.Instance<T, RuntimeException> retryer() {
return retryUtils.of(
Exponential.builder()
.interval(Duration.ofMillis(10))
.maxAttempt(10)
.maxInterval(Duration.ofMillis(100))
.build()
);
}

private static <E extends Throwable> Predicate<E> predicate() {
return (e) -> {
if (!(e.getCause() instanceof SQLException)) {
return false;
}

SQLException cause = (SQLException) e.getCause();

return
// standard deadlock
cause.getSQLState().equals("40001") ||
// postgres deadlock
cause.getSQLState().equals("40P01");
};
}

public void transaction(TransactionalRunnable transactional) {
RetryUtils.Instance<Object, Throwable> of = retryUtils.of(Exponential.builder()
.interval(Duration.ofMillis(10))
.maxAttempt(10)
.maxInterval(Duration.ofMillis(100))
.build()
);

this.<Void>retryer().runRetryIf(
predicate(),
() -> {
dslContext.transaction(transactional);
return null;
}
);
}

public <T> T transactionResult(TransactionalCallable<T> transactional) {
RetryUtils.Instance<Object, Throwable> of = retryUtils.of(Exponential.builder()
.interval(Duration.ofMillis(10))
.maxAttempt(10)
.maxInterval(Duration.ofMillis(100))
.build()
);

return this.<T>retryer().runRetryIf(
predicate(),
() -> dslContext.transactionResult(transactional)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import jakarta.inject.Singleton;
import org.apache.commons.lang3.tuple.Pair;
import org.jooq.*;
import org.jooq.exception.DataAccessException;
import org.jooq.impl.DSL;

import java.sql.SQLTransientException;
import java.time.Duration;
import java.time.LocalDate;
import java.time.ZonedDateTime;
Expand All @@ -42,7 +44,7 @@ public AbstractExecutionRepository(AbstractJdbcRepository<Execution> jdbcReposit
@Override
public Optional<Execution> findById(String id) {
return jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
Select<Record1<Object>> from = DSL
.using(configuration)
Expand All @@ -67,7 +69,7 @@ public ArrayListTotal<Execution> find(
@Nullable List<State.Type> state
) {
return this.jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
DSLContext context = DSL.using(configuration);

Expand Down Expand Up @@ -110,7 +112,7 @@ public ArrayListTotal<Execution> find(
@Override
public ArrayListTotal<Execution> findByFlowId(String namespace, String id, Pageable pageable) {
return this.jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
DSLContext context = DSL.using(configuration);

Expand Down Expand Up @@ -203,7 +205,7 @@ private Results dailyStatisticsQuery(List<Field<?>> fields, String query, ZonedD
));

return jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
SelectConditionStep<?> select = DSL
.using(configuration)
Expand Down Expand Up @@ -355,10 +357,11 @@ public List<ExecutionCount> executionCounts(
ZonedDateTime finalEndDate = endDate == null ? ZonedDateTime.now() : endDate;

List<ExecutionCount> result = this.jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
SelectConditionStep<?> select = this.jdbcRepository
.getDslContext()
DSLContext dslContext = DSL.using(configuration);

SelectConditionStep<?> select = dslContext
.select(List.of(
DSL.field("namespace"),
DSL.field("flow_id"),
Expand Down Expand Up @@ -440,7 +443,7 @@ public Execution save(DSLContext dslContext, Execution execution) {

public Executor lock(String executionId, Function<Pair<Execution, JdbcExecutorState>, Pair<Executor, JdbcExecutorState>> function) {
return this.jdbcRepository
.getDslContext()
.getDslContextWrapper()
.transactionResult(configuration -> {
DSLContext context = DSL.using(configuration);

Expand Down
Loading

0 comments on commit 89f6bc2

Please sign in to comment.