Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding FieldValue.numericAdd() #105

Merged
merged 19 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.firestore;

import static com.google.firebase.firestore.testutil.IntegrationTestUtil.testDocument;
import static com.google.firebase.firestore.testutil.IntegrationTestUtil.waitFor;
import static com.google.firebase.firestore.testutil.TestUtil.map;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;

import com.google.android.gms.tasks.Tasks;
import com.google.firebase.firestore.testutil.EventAccumulator;
import com.google.firebase.firestore.testutil.IntegrationTestUtil;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

@Ignore("Not yet available in production")
mikelehen marked this conversation as resolved.
Show resolved Hide resolved
public class NumericTransformsTest {
private static final double DOUBLE_EPSILON = 0.000001;

// A document reference to read and write to.
private DocumentReference docRef;

// Accumulator used to capture events during the test.
private EventAccumulator<DocumentSnapshot> accumulator;

// Listener registration for a listener maintained during the course of the test.
private ListenerRegistration listenerRegistration;

@Before
public void setUp() {
docRef = testDocument();
accumulator = new EventAccumulator<>();
listenerRegistration =
docRef.addSnapshotListener(MetadataChanges.INCLUDE, accumulator.listener());

// Wait for initial null snapshot to avoid potential races.
DocumentSnapshot initialSnapshot = accumulator.await();
assertFalse(initialSnapshot.exists());
}

@After
public void tearDown() {
listenerRegistration.remove();
IntegrationTestUtil.tearDown();
}

/** Writes some initialData and consumes the events generated. */
private void writeInitialData(Map<String, Object> initialData) {
waitFor(docRef.set(initialData));
accumulator.awaitRemoteEvent();
}

private void expectLocalAndRemoteValue(double expectedSum) {
DocumentSnapshot snap = accumulator.awaitLocalEvent();
assertEquals(expectedSum, snap.getDouble("sum"), DOUBLE_EPSILON);
snap = accumulator.awaitRemoteEvent();
assertEquals(expectedSum, snap.getDouble("sum"), DOUBLE_EPSILON);
}

private void expectLocalAndRemoteValue(long expectedSum) {
DocumentSnapshot snap = accumulator.awaitLocalEvent();
assertEquals(expectedSum, (long) snap.getLong("sum"));
snap = accumulator.awaitRemoteEvent();
assertEquals(expectedSum, (long) snap.getLong("sum"));
}

@Test
public void createDocumentWithIncrement() {
waitFor(docRef.set(map("sum", FieldValue.numericAdd(1337))));
expectLocalAndRemoteValue(1337L);
}
mikelehen marked this conversation as resolved.
Show resolved Hide resolved

@Test
public void integerIncrementExistingInteger() {
writeInitialData(map("sum", 1337L));
waitFor(docRef.update("sum", FieldValue.numericAdd(1)));
expectLocalAndRemoteValue(1338L);
}

@Test
public void doubleIncrementWithExistingDouble() {
mikelehen marked this conversation as resolved.
Show resolved Hide resolved
writeInitialData(map("sum", 13.37D));
waitFor(docRef.update("sum", FieldValue.numericAdd(0.1)));
expectLocalAndRemoteValue(13.47D);
}

@Test
public void integerIncrementExistingDouble() {
writeInitialData(map("sum", 13.37D));
waitFor(docRef.update("sum", FieldValue.numericAdd(1)));
expectLocalAndRemoteValue(14.37D);
}

@Test
public void doubleIncrementExistingInteger() {
writeInitialData(map("sum", 1337L));
waitFor(docRef.update("sum", FieldValue.numericAdd(0.1)));
expectLocalAndRemoteValue(1337.1D);
}

@Test
public void integerIncrementExistingString() {
writeInitialData(map("sum", "overwrite"));
waitFor(docRef.update("sum", FieldValue.numericAdd(1337)));
expectLocalAndRemoteValue(1337L);
}

@Test
public void doubleIncrementExistingString() {
writeInitialData(map("sum", "overwrite"));
waitFor(docRef.update("sum", FieldValue.numericAdd(13.37)));
expectLocalAndRemoteValue(13.37D);
}

@Test
public void multipleDoubleIncrements() throws ExecutionException, InterruptedException {
writeInitialData(map("sum", 0.0D));

Tasks.await(docRef.getFirestore().disableNetwork());

docRef.update("sum", FieldValue.numericAdd(0.1D));
docRef.update("sum", FieldValue.numericAdd(0.01D));
docRef.update("sum", FieldValue.numericAdd(0.001D));

DocumentSnapshot snap = accumulator.awaitLocalEvent();
assertEquals(0.1D, snap.getDouble("sum"), DOUBLE_EPSILON);
snap = accumulator.awaitLocalEvent();
assertEquals(0.11D, snap.getDouble("sum"), DOUBLE_EPSILON);
snap = accumulator.awaitLocalEvent();
assertEquals(0.111D, snap.getDouble("sum"), DOUBLE_EPSILON);

Tasks.await(docRef.getFirestore().enableNetwork());

snap = accumulator.awaitRemoteEvent();
assertEquals(0.111D, snap.getDouble("sum"), DOUBLE_EPSILON);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static junit.framework.Assert.fail;

import android.content.Context;
import android.net.SSLCertificateSocketFactory;
import android.support.test.InstrumentationRegistry;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
Expand All @@ -38,10 +39,12 @@
import com.google.firebase.firestore.local.Persistence;
import com.google.firebase.firestore.local.SQLitePersistence;
import com.google.firebase.firestore.model.DatabaseId;
import com.google.firebase.firestore.remote.Datastore;
import com.google.firebase.firestore.testutil.provider.FirestoreProvider;
import com.google.firebase.firestore.util.AsyncQueue;
import com.google.firebase.firestore.util.Logger;
import com.google.firebase.firestore.util.Logger.Level;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -56,6 +59,13 @@
/** A set of helper methods for tests */
public class IntegrationTestUtil {

// Whether the integration tests should run against a local Firestore emulator instead of the
// Production environment. Note that the Android Emulator treats "10.0.2.2" as its host machine.
private static final boolean CONNECT_TO_EMULATOR = false;

private static final String EMULATOR_HOST = "10.0.2.2";
private static final int EMULATOR_PORT = 8081;
mikelehen marked this conversation as resolved.
Show resolved Hide resolved

// Alternate project ID for creating "bad" references. Doesn't actually need to work.
public static final String BAD_PROJECT_ID = "test-project-2";

Expand All @@ -80,11 +90,19 @@ public static FirestoreProvider provider() {
}

public static DatabaseInfo testEnvDatabaseInfo() {
return new DatabaseInfo(
DatabaseId.forProject(provider.projectId()),
"test-persistenceKey",
provider.firestoreHost(),
/*sslEnabled=*/ true);
if (CONNECT_TO_EMULATOR) {
return new DatabaseInfo(
DatabaseId.forProject(provider.projectId()),
"test-persistenceKey",
String.format("%s:%d", EMULATOR_HOST, EMULATOR_PORT),
/*sslEnabled=*/ true);
} else {
return new DatabaseInfo(
DatabaseId.forProject(provider.projectId()),
"test-persistenceKey",
provider.firestoreHost(),
/*sslEnabled=*/ true);
}
}

public static FirebaseFirestoreSettings newTestSettings() {
Expand All @@ -93,11 +111,33 @@ public static FirebaseFirestoreSettings newTestSettings() {

public static FirebaseFirestoreSettings newTestSettingsWithSnapshotTimestampsEnabled(
boolean enabled) {
return new FirebaseFirestoreSettings.Builder()
.setHost(provider.firestoreHost())
.setPersistenceEnabled(true)
.setTimestampsInSnapshotsEnabled(enabled)
.build();
FirebaseFirestoreSettings.Builder settings = new FirebaseFirestoreSettings.Builder();

if (CONNECT_TO_EMULATOR) {
settings.setHost(String.format("%s:%d", EMULATOR_HOST, EMULATOR_PORT));

// Disable SSL and hostname verification
mikelehen marked this conversation as resolved.
Show resolved Hide resolved
OkHttpChannelBuilder channelBuilder =
new OkHttpChannelBuilder(EMULATOR_HOST, EMULATOR_PORT) {
@Override
protected String checkAuthority(String authority) {
return authority;
}
};
channelBuilder.hostnameVerifier((hostname, session) -> true);
SSLCertificateSocketFactory insecureFactory =
(SSLCertificateSocketFactory) SSLCertificateSocketFactory.getInsecure(0, null);
channelBuilder.sslSocketFactory(insecureFactory);

Datastore.overrideChannelBuilder(() -> channelBuilder);
} else {
settings.setHost(provider.firestoreHost());
}

settings.setPersistenceEnabled(true);
settings.setTimestampsInSnapshotsEnabled(enabled);

return settings.build();
}

/** Initializes a new Firestore instance that uses the default project. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ String getMethodName() {
}

List<Object> getElements() {
return this.elements;
return elements;
}
}

Expand All @@ -79,7 +79,25 @@ String getMethodName() {
}

List<Object> getElements() {
return this.elements;
return elements;
}
}

/* FieldValue class for numericAdd() transforms. */
static class NumericAddFieldValue extends FieldValue {
private final Number operand;

NumericAddFieldValue(Number operand) {
this.operand = operand;
}

@Override
String getMethodName() {
return "FieldValue.numericAdd";
}

Number getOperand() {
return operand;
}
}

Expand Down Expand Up @@ -134,4 +152,39 @@ public static FieldValue arrayUnion(@NonNull Object... elements) {
public static FieldValue arrayRemove(@NonNull Object... elements) {
return new ArrayRemoveFieldValue(Arrays.asList(elements));
}

/**
* Returns a special value that can be used with set() or update() that tells the server to add
* the given value to the field's current value.
*
* <p>If the current field value is an integer, possible integer overflows are resolved to
* Long.MAX_VALUE or Long.MIN_VALUE respectively. If the current field value is a double, both
mikelehen marked this conversation as resolved.
Show resolved Hide resolved
* values will be interpreted as doubles and the arithmetic will follow IEEE 754 semantics.
*
* <p>If field is not an integer or double, or if the field does not yet exist, the transformation
* will set the field to the given value.
*
* @return The FieldValue sentinel for use in a call to set() or update().
*/
@NonNull
@PublicApi
public static FieldValue numericAdd(long l) {
return new NumericAddFieldValue(l);
}

/**
* Returns a special value that can be used with set() or update() that tells the server to add
* the given value to the field's current value.
*
* <p>If the current value is an integer or a double, both the current and the given value will be
* interpreted as doubles and all arithmetic will follow IEEE 754 semantics. Otherwise, the
* transformation will set the field to the given value.
*
* @return The FieldValue sentinel for use in a call to set() or update().
*/
@NonNull
@PublicApi
public static FieldValue numericAdd(double l) {
return new NumericAddFieldValue(l);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.firebase.firestore.model.FieldPath;
import com.google.firebase.firestore.model.mutation.ArrayTransformOperation;
import com.google.firebase.firestore.model.mutation.FieldMask;
import com.google.firebase.firestore.model.mutation.NumericAddTransformOperation;
import com.google.firebase.firestore.model.mutation.ServerTimestampOperation;
import com.google.firebase.firestore.model.value.ArrayValue;
import com.google.firebase.firestore.model.value.BlobValue;
Expand All @@ -40,6 +41,7 @@
import com.google.firebase.firestore.model.value.GeoPointValue;
import com.google.firebase.firestore.model.value.IntegerValue;
import com.google.firebase.firestore.model.value.NullValue;
import com.google.firebase.firestore.model.value.NumberValue;
import com.google.firebase.firestore.model.value.ObjectValue;
import com.google.firebase.firestore.model.value.ReferenceValue;
import com.google.firebase.firestore.model.value.StringValue;
Expand Down Expand Up @@ -349,6 +351,13 @@ private void parseSentinelFieldValue(
ArrayTransformOperation arrayRemove = new ArrayTransformOperation.Remove(parsedElements);
context.addToFieldTransforms(context.getPath(), arrayRemove);

} else if (value instanceof com.google.firebase.firestore.FieldValue.NumericAddFieldValue) {
com.google.firebase.firestore.FieldValue.NumericAddFieldValue numericAddFieldValue =
(com.google.firebase.firestore.FieldValue.NumericAddFieldValue) value;
NumberValue operand = (NumberValue) parseQueryValue(numericAddFieldValue.getOperand());
NumericAddTransformOperation numericAdd = new NumericAddTransformOperation(operand);
context.addToFieldTransforms(context.getPath(), numericAdd);

} else {
throw Assert.fail("Unknown FieldValue type: %s", Util.typeName(value));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ com.google.firebase.firestore.proto.WriteBatch encodeMutationBatch(MutationBatch

result.setBatchId(batch.getBatchId());
result.setLocalWriteTime(rpcSerializer.encodeTimestamp(batch.getLocalWriteTime()));
for (Mutation mutation : batch.getBaseMutations()) {
result.addBaseWrites(rpcSerializer.encodeMutation(mutation));
}
for (Mutation mutation : batch.getMutations()) {
result.addWrites(rpcSerializer.encodeMutation(mutation));
}
Expand All @@ -171,13 +174,17 @@ MutationBatch decodeMutationBatch(com.google.firebase.firestore.proto.WriteBatch
int batchId = batch.getBatchId();
Timestamp localWriteTime = rpcSerializer.decodeTimestamp(batch.getLocalWriteTime());

int count = batch.getWritesCount();
List<Mutation> mutations = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
int baseMutationsCount = batch.getBaseWritesCount();
List<Mutation> baseMutations = new ArrayList<>(baseMutationsCount);
for (int i = 0; i < baseMutationsCount; i++) {
baseMutations.add(rpcSerializer.decodeMutation(batch.getBaseWrites(i)));
}
int mutationsCount = batch.getWritesCount();
List<Mutation> mutations = new ArrayList<>(mutationsCount);
for (int i = 0; i < mutationsCount; i++) {
mutations.add(rpcSerializer.decodeMutation(batch.getWrites(i)));
}

return new MutationBatch(batchId, localWriteTime, mutations);
return new MutationBatch(batchId, localWriteTime, baseMutations, mutations);
}

com.google.firebase.firestore.proto.Target encodeQueryData(QueryData queryData) {
Expand Down
Loading