Skip to content

Commit

Permalink
#220 - Use EntityOperations in SimpleR2dbcRepository.
Browse files Browse the repository at this point in the history
Original pull request: #287.
  • Loading branch information
mp911de committed Feb 12, 2020
1 parent b5445b8 commit 8313630
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.core.DatabaseClient;
import org.springframework.data.r2dbc.core.R2dbcEntityTemplate;
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
import org.springframework.data.r2dbc.repository.R2dbcRepository;
import org.springframework.data.r2dbc.repository.query.R2dbcQueryMethod;
Expand Down Expand Up @@ -92,8 +93,8 @@ protected Object getTargetRepository(RepositoryInformation information) {
RelationalEntityInformation<?, ?> entityInformation = getEntityInformation(information.getDomainType(),
information);

return getTargetRepositoryViaReflection(information, entityInformation, this.databaseClient, this.converter,
this.dataAccessStrategy);
return getTargetRepositoryViaReflection(information, entityInformation,
new R2dbcEntityTemplate(this.databaseClient, this.dataAccessStrategy), this.converter);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.List;

import org.reactivestreams.Publisher;

import org.springframework.dao.TransientDataAccessResourceException;
import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.core.DatabaseClient;
import org.springframework.data.r2dbc.core.PreparedOperation;
import org.springframework.data.r2dbc.core.R2dbcEntityOperations;
import org.springframework.data.r2dbc.core.R2dbcEntityTemplate;
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
import org.springframework.data.r2dbc.core.StatementMapper;
import org.springframework.data.r2dbc.query.Criteria;
import org.springframework.data.r2dbc.query.Query;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.sql.Functions;
import org.springframework.data.relational.core.sql.Select;
Expand All @@ -38,6 +36,7 @@
import org.springframework.data.relational.core.sql.render.SqlRenderer;
import org.springframework.data.relational.repository.query.RelationalEntityInformation;
import org.springframework.data.repository.reactive.ReactiveCrudRepository;
import org.springframework.data.util.Lazy;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;

Expand All @@ -51,16 +50,45 @@
public class SimpleR2dbcRepository<T, ID> implements ReactiveCrudRepository<T, ID> {

private final RelationalEntityInformation<T, ID> entity;
private final DatabaseClient databaseClient;
private final R2dbcConverter converter;
private final ReactiveDataAccessStrategy accessStrategy;
private final R2dbcEntityOperations entityOperations;
private final Lazy<RelationalPersistentProperty> idProperty;

/**
* Create a new {@link SimpleR2dbcRepository}.
*
* @param entity
* @param entityOperations
* @param converter
* @since 1.1
*/
SimpleR2dbcRepository(RelationalEntityInformation<T, ID> entity, R2dbcEntityOperations entityOperations,
R2dbcConverter converter) {

this.entity = entity;
this.entityOperations = entityOperations;
this.idProperty = Lazy.of(() -> converter //
.getMappingContext() //
.getRequiredPersistentEntity(this.entity.getJavaType()) //
.getRequiredIdProperty());
}

/**
* Create a new {@link SimpleR2dbcRepository}.
*
* @param entity
* @param databaseClient
* @param converter
* @param accessStrategy
*/
public SimpleR2dbcRepository(RelationalEntityInformation<T, ID> entity, DatabaseClient databaseClient,
R2dbcConverter converter, ReactiveDataAccessStrategy accessStrategy) {

this.entity = entity;
this.databaseClient = databaseClient;
this.converter = converter;
this.accessStrategy = accessStrategy;
this.entityOperations = new R2dbcEntityTemplate(databaseClient);
this.idProperty = Lazy.of(() -> converter //
.getMappingContext() //
.getRequiredPersistentEntity(this.entity.getJavaType()) //
.getRequiredIdProperty());
}

/* (non-Javadoc)
Expand All @@ -73,28 +101,10 @@ public <S extends T> Mono<S> save(S objectToSave) {
Assert.notNull(objectToSave, "Object to save must not be null!");

if (this.entity.isNew(objectToSave)) {

return this.databaseClient.insert() //
.into(this.entity.getJavaType()) //
.table(this.entity.getTableName()).using(objectToSave) //
.map(this.converter.populateIdIfNecessary(objectToSave)) //
.first() //
.defaultIfEmpty(objectToSave);
return this.entityOperations.insert(objectToSave);
}

return this.databaseClient.update() //
.table(this.entity.getJavaType()) //
.table(this.entity.getTableName()).using(objectToSave) //
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {

if (rowsUpdated == 0) {
sink.error(new TransientDataAccessResourceException(
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
this.entity.getTableName(), this.entity.getId(objectToSave))));
} else {
sink.next(objectToSave);
}
});
return this.entityOperations.update(objectToSave);
}

/* (non-Javadoc)
Expand Down Expand Up @@ -129,20 +139,7 @@ public Mono<T> findById(ID id) {

Assert.notNull(id, "Id must not be null!");

List<SqlIdentifier> columns = this.accessStrategy.getAllColumns(this.entity.getJavaType());
String idProperty = getIdProperty().getName();

StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType());
StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()) //
.withProjection(columns) //
.withCriteria(Criteria.where(idProperty).is(id));

PreparedOperation<?> operation = mapper.getMappedObject(selectSpec);

return this.databaseClient.execute(operation) //
.as(this.entity.getJavaType()) //
.fetch() //
.one();
return this.entityOperations.selectOne(getIdQuery(id), this.entity.getJavaType());
}

/* (non-Javadoc)
Expand All @@ -161,18 +158,7 @@ public Mono<Boolean> existsById(ID id) {

Assert.notNull(id, "Id must not be null!");

String idProperty = getIdProperty().getName();

StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType());
StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()).withProjection(idProperty) //
.withCriteria(Criteria.where(idProperty).is(id));

PreparedOperation<?> operation = mapper.getMappedObject(selectSpec);

return this.databaseClient.execute(operation) //
.map((r, md) -> r) //
.first() //
.hasElement();
return this.entityOperations.exists(getIdQuery(id), this.entity.getJavaType());
}

/* (non-Javadoc)
Expand All @@ -188,7 +174,7 @@ public Mono<Boolean> existsById(Publisher<ID> publisher) {
*/
@Override
public Flux<T> findAll() {
return this.databaseClient.select().from(this.entity.getJavaType()).fetch().all();
return this.entityOperations.select(Query.empty(), this.entity.getJavaType());
}

/* (non-Javadoc)
Expand Down Expand Up @@ -216,17 +202,9 @@ public Flux<T> findAllById(Publisher<ID> idPublisher) {
return Flux.empty();
}

List<SqlIdentifier> columns = this.accessStrategy.getAllColumns(this.entity.getJavaType());
String idProperty = getIdProperty().getName();

StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType());
StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()) //
.withProjection(columns) //
.withCriteria(Criteria.where(idProperty).in(ids));

PreparedOperation<?> operation = mapper.getMappedObject(selectSpec);

return this.databaseClient.execute(operation).as(this.entity.getJavaType()).fetch().all();
return this.entityOperations.select(Query.query(Criteria.where(idProperty).in(ids)), this.entity.getJavaType());
});
}

Expand All @@ -235,17 +213,7 @@ public Flux<T> findAllById(Publisher<ID> idPublisher) {
*/
@Override
public Mono<Long> count() {

Table table = Table.create(this.accessStrategy.toSql(this.entity.getTableName()));
Select select = StatementBuilder //
.select(Functions.count(table.column(this.accessStrategy.toSql(getIdProperty().getColumnName())))) //
.from(table) //
.build();

return this.databaseClient.execute(SqlRenderer.toString(select)) //
.map((r, md) -> r.get(0, Long.class)) //
.first() //
.defaultIfEmpty(0L);
return this.entityOperations.count(Query.empty(), this.entity.getJavaType());
}

/* (non-Javadoc)
Expand All @@ -257,13 +225,7 @@ public Mono<Void> deleteById(ID id) {

Assert.notNull(id, "Id must not be null!");

return this.databaseClient.delete() //
.from(this.entity.getJavaType()) //
.table(this.entity.getTableName()) //
.matching(Criteria.where(getIdProperty().getName()).is(id)) //
.fetch() //
.rowsUpdated() //
.then();
return this.entityOperations.delete(getIdQuery(id), this.entity.getJavaType()).then();
}

/* (non-Javadoc)
Expand All @@ -274,20 +236,16 @@ public Mono<Void> deleteById(ID id) {
public Mono<Void> deleteById(Publisher<ID> idPublisher) {

Assert.notNull(idPublisher, "The Id Publisher must not be null!");
StatementMapper statementMapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType());

return Flux.from(idPublisher).buffer().filter(ids -> !ids.isEmpty()).concatMap(ids -> {

if (ids.isEmpty()) {
return Flux.empty();
}

return this.databaseClient.delete() //
.from(this.entity.getJavaType()) //
.table(this.entity.getTableName()) //
.matching(Criteria.where(getIdProperty().getName()).in(ids)) //
.fetch() //
.rowsUpdated();
String idProperty = getIdProperty().getName();

return this.entityOperations.delete(Query.query(Criteria.where(idProperty).in(ids)), this.entity.getJavaType());
}).then();
}

Expand Down Expand Up @@ -336,14 +294,14 @@ public Mono<Void> deleteAll(Publisher<? extends T> objectPublisher) {
@Override
@Transactional
public Mono<Void> deleteAll() {
return this.databaseClient.delete().from(this.entity.getTableName()).then();
return this.entityOperations.delete(Query.empty(), this.entity.getJavaType()).then();
}

private RelationalPersistentProperty getIdProperty() {
return this.idProperty.get();
}

return this.converter //
.getMappingContext() //
.getRequiredPersistentEntity(this.entity.getJavaType()) //
.getRequiredIdProperty();
private Query getIdQuery(Object id) {
return Query.query(Criteria.where(getIdProperty().getName()).is(id));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.annotation.Id;
import org.springframework.data.r2dbc.convert.MappingR2dbcConverter;
import org.springframework.data.r2dbc.convert.R2dbcConverter;
import org.springframework.data.r2dbc.core.DatabaseClient;
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
import org.springframework.data.relational.repository.query.RelationalEntityInformation;
import org.springframework.data.relational.repository.support.MappingRelationalEntityInformation;
import org.springframework.data.repository.Repository;
Expand All @@ -41,18 +42,15 @@
@RunWith(MockitoJUnitRunner.class)
public class R2dbcRepositoryFactoryUnitTests {

R2dbcConverter r2dbcConverter = new MappingR2dbcConverter(new R2dbcMappingContext());

@Mock DatabaseClient databaseClient;
@Mock R2dbcConverter r2dbcConverter;
@Mock ReactiveDataAccessStrategy dataAccessStrategy;
@Mock @SuppressWarnings("rawtypes") MappingContext mappingContext;
@Mock @SuppressWarnings("rawtypes") RelationalPersistentEntity entity;

@Before
@SuppressWarnings("unchecked")
public void before() {
when(mappingContext.getRequiredPersistentEntity(Person.class)).thenReturn(entity);
when(dataAccessStrategy.getConverter()).thenReturn(r2dbcConverter);
when(r2dbcConverter.getMappingContext()).thenReturn(mappingContext);
}

@Test
Expand All @@ -75,5 +73,7 @@ public void createsRepositoryWithIdTypeLong() {

interface MyPersonRepository extends Repository<Person, Long> {}

static class Person {}
static class Person {
@Id long id;
}
}

0 comments on commit 8313630

Please sign in to comment.