Skip to content

Commit

Permalink
Add support for $covariancePop and $covarianceSamp aggregation ex…
Browse files Browse the repository at this point in the history
…pressions.

Closes: #3712
Original pull request: #3740.
  • Loading branch information
christophstrobl authored and mp911de committed Aug 23, 2021
1 parent f9f4c46 commit c574e5c
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,63 @@ public StdDevSamp stdDevSamp() {
return usesFieldRef() ? StdDevSamp.stdDevSampOf(fieldReference) : StdDevSamp.stdDevSampOf(expression);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the population covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(String fieldReference) {
return covariancePop().and(fieldReference);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the population covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(AggregationExpression expression) {
return covariancePop().and(expression);
}

private CovariancePop covariancePop() {
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the sample covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(String fieldReference) {
return covarianceSamp().and(fieldReference);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
return covarianceSamp().and(expression);
}

private CovarianceSamp covarianceSamp() {
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
: CovarianceSamp.covarianceSampOf(expression);
}

private boolean usesFieldRef() {
return fieldReference != null;
}
Expand Down Expand Up @@ -658,4 +715,124 @@ public Document toDocument(Object value, AggregationOperationContext context) {
return super.toDocument(value, context);
}
}

/**
* {@link AggregationExpression} for {@code $covariancePop}.
*
* @author Christoph Strobl
* @since 3.3
*/
public static class CovariancePop extends AbstractAggregationExpression {

private CovariancePop(Object value) {
super(value);
}

/**
* Creates new {@link CovariancePop}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public static CovariancePop covariancePopOf(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null!");
return new CovariancePop(asFields(fieldReference));
}

/**
* Creates new {@link CovariancePop}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public static CovariancePop covariancePopOf(AggregationExpression expression) {
return new CovariancePop(Collections.singletonList(expression));
}

/**
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public CovariancePop and(String fieldReference) {
return new CovariancePop(append(asFields(fieldReference)));
}

/**
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
*/
public CovariancePop and(AggregationExpression expression) {
return new CovariancePop(append(expression));
}

@Override
protected String getMongoMethod() {
return "$covariancePop";
}
}

/**
* {@link AggregationExpression} for {@code $covarianceSamp}.
*
* @author Christoph Strobl
* @since 3.3
*/
public static class CovarianceSamp extends AbstractAggregationExpression {

private CovarianceSamp(Object value) {
super(value);
}

/**
* Creates new {@link CovarianceSamp}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public static CovarianceSamp covarianceSampOf(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null!");
return new CovarianceSamp(asFields(fieldReference));
}

/**
* Creates new {@link CovarianceSamp}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public static CovarianceSamp covarianceSampOf(AggregationExpression expression) {
return new CovarianceSamp(Collections.singletonList(expression));
}

/**
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public CovarianceSamp and(String fieldReference) {
return new CovarianceSamp(append(asFields(fieldReference)));
}

/**
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovarianceSamp}.
*/
public CovarianceSamp and(AggregationExpression expression) {
return new CovarianceSamp(append(expression));
}

@Override
protected String getMongoMethod() {
return "$covarianceSamp";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.List;

import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Avg;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
Expand Down Expand Up @@ -511,6 +513,63 @@ public StdDevSamp stdDevSamp() {
: AccumulatorOperators.StdDevSamp.stdDevSampOf(expression);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the population covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(String fieldReference) {
return covariancePop().and(fieldReference);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the population covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovariancePop covariancePop(AggregationExpression expression) {
return covariancePop().and(expression);
}

private CovariancePop covariancePop() {
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
* field to calculate the sample covariance of the two.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(String fieldReference) {
return covarianceSamp().and(fieldReference);
}

/**
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link CovariancePop}.
* @since 3.3
*/
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
return covarianceSamp().and(expression);
}

private CovarianceSamp covarianceSamp() {
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
: CovarianceSamp.covarianceSampOf(expression);
}

/**
* Creates new {@link AggregationExpression} that rounds a number to a whole integer or to a specified decimal
* place.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ public class MethodReferenceNode extends ExpressionNode {
map.put("addToSet", singleArgRef().forOperator("$addToSet"));
map.put("stdDevPop", arrayArgRef().forOperator("$stdDevPop"));
map.put("stdDevSamp", arrayArgRef().forOperator("$stdDevSamp"));
map.put("covariancePop", arrayArgRef().forOperator("$covariancePop"));
map.put("covarianceSamp", arrayArgRef().forOperator("$covarianceSamp"));

// TYPE OPERATORS
map.put("type", singleArgRef().forOperator("$type"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.util.aggregation;

import org.bson.Document;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
import org.springframework.data.mongodb.core.aggregation.Field;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.lang.Nullable;

/**
* @author Christoph Strobl
*/
public class TestAggregationContext implements AggregationOperationContext {

private final AggregationOperationContext delegate;

private TestAggregationContext(AggregationOperationContext delegate) {
this.delegate = delegate;
}

public static AggregationOperationContext contextFor(@Nullable Class<?> type) {

MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE,
new MongoMappingContext());
mongoConverter.afterPropertiesSet();

return contextFor(type, mongoConverter);
}

public static AggregationOperationContext contextFor(@Nullable Class<?> type, MongoConverter mongoConverter) {

if (type == null) {
return Aggregation.DEFAULT_CONTEXT;
}

return new TestAggregationContext(new TypeBasedAggregationOperationContext(type, mongoConverter.getMappingContext(),
new QueryMapper(mongoConverter)).continueOnMissingFieldReference());
}

@Override
public Document getMappedObject(Document document, @Nullable Class<?> type) {
return delegate.getMappedObject(document, type);
}

@Override
public FieldReference getReference(Field field) {
return delegate.getReference(field);
}

@Override
public FieldReference getReference(String name) {
return delegate.getReference(name);
}
}
Loading

0 comments on commit c574e5c

Please sign in to comment.