Skip to content

Commit

Permalink
Make MultimapCodec async-compatible.
Browse files Browse the repository at this point in the history
The previous implementation contained unsound casts. LinkedHashMultimap is
a declarable type that cannot be assigned from ImmutableSetMultimap which
would have led to a runtime crash. So the codecs are broken out and registered
individually.

Minor ArrayProcessor changes.
* Use a static import for array offset constants.
* Add a deserializeObjectArrayFully method.

PiperOrigin-RevId: 595569849
Change-Id: I8add48dacd13d4e08159c48e4145632055af2e15
  • Loading branch information
aoeui authored and copybara-github committed Jan 4, 2024
1 parent dc3a359 commit 1f2d576
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
import static com.google.devtools.build.lib.skyframe.serialization.CodecHelpers.writeChar;
import static com.google.devtools.build.lib.skyframe.serialization.CodecHelpers.writeShort;
import static com.google.devtools.build.lib.unsafe.UnsafeProvider.unsafe;
import static sun.misc.Unsafe.ARRAY_OBJECT_BASE_OFFSET;
import static sun.misc.Unsafe.ARRAY_OBJECT_INDEX_SCALE;

import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.IOException;
import java.lang.reflect.Array;
import sun.misc.Unsafe;

/**
* Stateless class that encodes and decodes arrays that may be multi-dimensional.
Expand Down Expand Up @@ -161,7 +162,11 @@ public final void deserialize(
unsafe().putObject(obj, offset, arr);
for (int i = 0; i < length; ++i) {
deserialize(
context, codedIn, componentType, arr, OBJECT_ARR_OFFSET + OBJECT_ARR_SCALE * i);
context,
codedIn,
componentType,
arr,
ARRAY_OBJECT_BASE_OFFSET + ARRAY_OBJECT_INDEX_SCALE * i);
}
return;
}
Expand Down Expand Up @@ -418,7 +423,11 @@ public void deserialize(
if (componentType.isArray()) {
for (int i = 0; i < length; ++i) {
deserialize(
context, codedIn, componentType, arr, OBJECT_ARR_OFFSET + OBJECT_ARR_SCALE * i);
context,
codedIn,
componentType,
arr,
ARRAY_OBJECT_BASE_OFFSET + ARRAY_OBJECT_INDEX_SCALE * i);
}
return;
}
Expand All @@ -441,10 +450,16 @@ public static void deserializeObjectArray(
AsyncDeserializationContext context, CodedInputStream codedIn, Object arr, int length)
throws IOException, SerializationException {
for (int i = 0; i < length; ++i) {
context.deserialize(codedIn, arr, OBJECT_ARR_OFFSET + OBJECT_ARR_SCALE * i);
context.deserialize(codedIn, arr, ARRAY_OBJECT_BASE_OFFSET + ARRAY_OBJECT_INDEX_SCALE * i);
}
}

private static final int OBJECT_ARR_OFFSET = Unsafe.ARRAY_OBJECT_BASE_OFFSET;
private static final int OBJECT_ARR_SCALE = Unsafe.ARRAY_OBJECT_INDEX_SCALE;
public static void deserializeObjectArrayFully(
AsyncDeserializationContext context, CodedInputStream codedIn, Object arr, int length)
throws IOException, SerializationException {
for (int i = 0; i < length; ++i) {
context.deserializeFully(
codedIn, arr, ARRAY_OBJECT_BASE_OFFSET + ARRAY_OBJECT_INDEX_SCALE * i);
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// Copyright 2018 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.skyframe.serialization;

import static com.google.devtools.build.lib.skyframe.serialization.ArrayProcessor.deserializeObjectArray;
import static com.google.devtools.build.lib.skyframe.serialization.ArrayProcessor.deserializeObjectArrayFully;
import static sun.misc.Unsafe.ARRAY_OBJECT_BASE_OFFSET;
import static sun.misc.Unsafe.ARRAY_OBJECT_INDEX_SCALE;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import com.google.errorprone.annotations.Keep;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.function.Supplier;

/**
* Codecs for {@link Multimap}. Handles {@link ImmutableListMultimap}, {@link ImmutableSetMultimap}
* and {@link LinkedHashMultimap}.
*/
@SuppressWarnings({"unchecked", "rawtypes", "NonApiType"})
public final class MultimapCodecs {
@Keep // used reflectively
private static class ImmutableListMultimapCodec
extends DeferredObjectCodec<ImmutableListMultimap> {
@Override
public Class<ImmutableListMultimap> getEncodedClass() {
return ImmutableListMultimap.class;
}

@Override
public void serialize(
SerializationContext context, ImmutableListMultimap obj, CodedOutputStream codedOut)
throws SerializationException, IOException {
serializeMultimap(context, obj, codedOut);
}

@Override
public Supplier<ImmutableListMultimap> deserializeDeferred(
AsyncDeserializationContext context, CodedInputStream codedIn)
throws SerializationException, IOException {
int size = codedIn.readInt32();
if (size == 0) {
return ImmutableListMultimap::of;
}

ImmutableListMultimapBuffer buffer = new ImmutableListMultimapBuffer(size);
for (int i = 0; i < size; i++) {
context.deserializeFully(
codedIn, buffer.keys, ARRAY_OBJECT_BASE_OFFSET + i * ARRAY_OBJECT_INDEX_SCALE);
int valuesCount = codedIn.readInt32();
Object[] values = new Object[valuesCount];
buffer.values[i] = values;
// The builder merely collects these references in an ArrayList, so (unlike the set-type
// multimaps) these values do not need to be fully deserialized.
deserializeObjectArray(context, codedIn, buffer.values[i], valuesCount);
}
return buffer;
}
}

@Keep // used reflectively
private static class ImmutableSetMultimapCodec extends DeferredObjectCodec<ImmutableSetMultimap> {
@Override
public Class<ImmutableSetMultimap> getEncodedClass() {
return ImmutableSetMultimap.class;
}

@Override
public void serialize(
SerializationContext context, ImmutableSetMultimap obj, CodedOutputStream codedOut)
throws SerializationException, IOException {
serializeMultimap(context, obj, codedOut);
}

@Override
public Supplier<ImmutableSetMultimap> deserializeDeferred(
AsyncDeserializationContext context, CodedInputStream codedIn)
throws SerializationException, IOException {
int size = codedIn.readInt32();
if (size == 0) {
return ImmutableSetMultimap::of;
}

ImmutableSetMultimapBuffer buffer = new ImmutableSetMultimapBuffer(size);
deserializeSetMultimap(context, codedIn, buffer);
return buffer;
}
}

@Keep // used reflectively
private static class LinkedHashMultimapCodec extends DeferredObjectCodec<LinkedHashMultimap> {
@Override
public Class<LinkedHashMultimap> getEncodedClass() {
return LinkedHashMultimap.class;
}

@Override
public void serialize(
SerializationContext context, LinkedHashMultimap obj, CodedOutputStream codedOut)
throws SerializationException, IOException {
serializeMultimap(context, obj, codedOut);
}

@Override
public Supplier<LinkedHashMultimap> deserializeDeferred(
AsyncDeserializationContext context, CodedInputStream codedIn)
throws SerializationException, IOException {
int size = codedIn.readInt32();
if (size == 0) {
return LinkedHashMultimap::create;
}

LinkedHashMultimapBuffer buffer = new LinkedHashMultimapBuffer(size);
deserializeSetMultimap(context, codedIn, buffer);
return buffer;
}
}

private static void serializeMultimap(
SerializationContext context, Multimap obj, CodedOutputStream codedOut)
throws SerializationException, IOException {
Map map = obj.asMap();
codedOut.writeInt32NoTag(map.size());
for (Object next : map.entrySet()) {
Map.Entry entry = (Map.Entry) next;

context.serialize(entry.getKey(), codedOut);

Collection values = (Collection) entry.getValue();
codedOut.writeInt32NoTag(values.size());
for (Object value : values) {
context.serialize(value, codedOut);
}
}
}

/** Takes care to fully deserialize all keys and values as they will be used in sets. */
private static void deserializeSetMultimap(
AsyncDeserializationContext context, CodedInputStream codedIn, MultimapBuffer buffer)
throws SerializationException, IOException {
for (int i = 0; i < buffer.size(); i++) {
context.deserializeFully(
codedIn, buffer.keys, ARRAY_OBJECT_BASE_OFFSET + i * ARRAY_OBJECT_INDEX_SCALE);

int valuesCount = codedIn.readInt32();
Object[] values = new Object[valuesCount];
buffer.values[i] = values;
// The builder uses a set to collect the values so they must be complete.
deserializeObjectArrayFully(context, codedIn, values, valuesCount);
}
}

private static class MultimapBuffer {
final Object[] keys;
final Object[][] values;

private MultimapBuffer(int size) {
this.keys = new Object[size];
this.values = new Object[size][];
}

int size() {
return keys.length;
}
}

private static class ImmutableListMultimapBuffer extends MultimapBuffer
implements Supplier<ImmutableListMultimap> {
private ImmutableListMultimapBuffer(int size) {
super(size);
}

@Override
public ImmutableListMultimap get() {
ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder();
for (int i = 0; i < size(); i++) {
builder.putAll(keys[i], values[i]);
}
return builder.build();
}
}

private static class ImmutableSetMultimapBuffer extends MultimapBuffer
implements Supplier<ImmutableSetMultimap> {
private ImmutableSetMultimapBuffer(int size) {
super(size);
}

@Override
public ImmutableSetMultimap get() {
ImmutableSetMultimap.Builder builder = ImmutableSetMultimap.builder();
for (int i = 0; i < size(); i++) {
builder.putAll(keys[i], values[i]);
}
return builder.build();
}
}

private static class LinkedHashMultimapBuffer extends MultimapBuffer
implements Supplier<LinkedHashMultimap> {
private LinkedHashMultimapBuffer(int size) {
super(size);
}

@Override
public LinkedHashMultimap get() {
int totalValues = 0;
for (int i = 0; i < size(); i++) {
totalValues += values[i].length;
}
LinkedHashMultimap result = LinkedHashMultimap.create(size(), totalValues / size());
for (int i = 0; i < size(); i++) {
result.putAll(keys[i], Arrays.asList(values[i]));
}
return result;
}
}

private MultimapCodecs() {}
}

0 comments on commit 1f2d576

Please sign in to comment.