Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-36749][state/forst] Implement rescaling for ForStStateBackend #25676

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ public OperatorSnapshotFutures snapshotState(
boolean isUsingCustomRawKeyedState,
boolean useAsyncState)
throws CheckpointException {
KeyGroupRange keyGroupRange =
null != keyedStateBackend
? keyedStateBackend.getKeyGroupRange()
: KeyGroupRange.EMPTY_KEY_GROUP_RANGE;
KeyGroupRange keyGroupRange = KeyGroupRange.EMPTY_KEY_GROUP_RANGE;
if (keyedStateBackend != null) {
keyGroupRange = keyedStateBackend.getKeyGroupRange();
} else if (asyncKeyedStateBackend != null) {
keyGroupRange = asyncKeyedStateBackend.getKeyGroupRange();
}

OperatorSnapshotFutures snapshotInProgress = new OperatorSnapshotFutures();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
Expand All @@ -44,9 +45,12 @@
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.ConfigurableStateBackend;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.KeyedStateBackendParametersImpl;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.SharedStateRegistry;
import org.apache.flink.runtime.state.SharedStateRegistryImpl;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.TestTaskStateManager;
Expand All @@ -66,6 +70,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.RunnableFuture;

import static java.util.Arrays.asList;
Expand Down Expand Up @@ -136,7 +141,7 @@ protected CheckpointStreamFactory createStreamFactory() throws Exception {
}

protected <K> AsyncKeyedStateBackend<K> createAsyncKeyedBackend(
TypeSerializer<K> keySerializer, int numberOfKeyGroups, Environment env)
TypeSerializer<K> keySerializer, KeyGroupRange keyGroupRange, Environment env)
throws Exception {

env.setCheckpointStorageAccess(getCheckpointStorageAccess());
Expand All @@ -148,8 +153,8 @@ protected <K> AsyncKeyedStateBackend<K> createAsyncKeyedBackend(
new JobID(),
"test_op",
keySerializer,
numberOfKeyGroups,
new KeyGroupRange(0, numberOfKeyGroups - 1),
keyGroupRange.getNumberOfKeyGroups(),
keyGroupRange,
env.getTaskKvStateRegistry(),
TtlTimeProvider.DEFAULT,
getMetricGroup(),
Expand All @@ -165,7 +170,7 @@ protected StateBackend.CustomInitializationMetrics getCustomInitializationMetric

protected <K> AsyncKeyedStateBackend<K> restoreAsyncKeyedBackend(
TypeSerializer<K> keySerializer,
int numberOfKeyGroups,
KeyGroupRange keyGroupRange,
List<KeyedStateHandle> state,
Environment env)
throws Exception {
Expand All @@ -177,8 +182,8 @@ protected <K> AsyncKeyedStateBackend<K> restoreAsyncKeyedBackend(
new JobID(),
"test_op",
keySerializer,
numberOfKeyGroups,
new KeyGroupRange(0, numberOfKeyGroups - 1),
keyGroupRange.getNumberOfKeyGroups(),
keyGroupRange,
Copy link

@davidradl davidradl Nov 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need both parameters - if one can be derived from another?

env.getTaskKvStateRegistry(),
TtlTimeProvider.DEFAULT,
getMetricGroup(),
Expand Down Expand Up @@ -212,7 +217,11 @@ void testAsyncKeyedStateBackendSnapshot() throws Exception {
KeyedStateHandle stateHandle;

try {
backend = createAsyncKeyedBackend(IntSerializer.INSTANCE, jobMaxParallelism, env);
backend =
createAsyncKeyedBackend(
IntSerializer.INSTANCE,
new KeyGroupRange(0, jobMaxParallelism - 1),
env);
aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
Expand Down Expand Up @@ -299,7 +308,7 @@ void testAsyncKeyedStateBackendSnapshot() throws Exception {
backend =
restoreAsyncKeyedBackend(
IntSerializer.INSTANCE,
jobMaxParallelism,
new KeyGroupRange(0, jobMaxParallelism - 1),
Collections.singletonList(stateHandle),
env);
aec =
Expand Down Expand Up @@ -345,6 +354,151 @@ void testAsyncKeyedStateBackendSnapshot() throws Exception {
assertThat(testExceptionHandler.exception).isNull();
}

@TestTemplate
void testAsyncStateBackendScaleUp() throws Exception {
testKeyGroupSnapshotRestore(2, 5, 10);
}

@TestTemplate
void testAsyncStateBackendScaleDown() throws Exception {
testKeyGroupSnapshotRestore(4, 3, 10);
}

private void testKeyGroupSnapshotRestore(
int sourceParallelism, int targetParallelism, int maxParallelism) throws Exception {

int aecBatchSize = 1;
long aecBufferTimeout = 1;
int aecMaxInFlightRecords = 1000;
Random random = new Random();
List<ValueStateDescriptor<Integer>> stateDescriptors = new ArrayList<>(maxParallelism);
List<Integer> keyInKeyGroups = new ArrayList<>(maxParallelism);
List<Integer> expectedValue = new ArrayList<>(maxParallelism);
for (int i = 0; i < maxParallelism; ++i) {
// all states have different name to mock that all the parallelisms of one operator have
// different states.
stateDescriptors.add(
new ValueStateDescriptor<>("state" + i, BasicTypeInfo.INT_TYPE_INFO));
}

CheckpointStreamFactory streamFactory = createStreamFactory();
SharedStateRegistry sharedStateRegistry = new SharedStateRegistryImpl();
AsyncExecutionController<Integer> aec;

TestAsyncFrameworkExceptionHandler testExceptionHandler =
new TestAsyncFrameworkExceptionHandler();

List<KeyedStateHandle> snapshots = new ArrayList<>(sourceParallelism);
for (int i = 0; i < sourceParallelism; ++i) {
KeyGroupRange range =
KeyGroupRange.of(
maxParallelism * i / sourceParallelism,
maxParallelism * (i + 1) / sourceParallelism - 1);
AsyncKeyedStateBackend<Integer> backend =
createAsyncKeyedBackend(IntSerializer.INSTANCE, range, env);
aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
testExceptionHandler,
backend.createStateExecutor(),
maxParallelism,
aecBatchSize,
aecBufferTimeout,
aecMaxInFlightRecords,
null);
backend.setup(aec);

try {
for (int j = range.getStartKeyGroup(); j <= range.getEndKeyGroup(); ++j) {
ValueState<Integer> state =
backend.createState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescriptors.get(j));
int keyInKeyGroup =
getKeyInKeyGroup(random, maxParallelism, KeyGroupRange.of(j, j));
RecordContext recordContext = aec.buildContext(keyInKeyGroup, keyInKeyGroup);
;
recordContext.retain();
aec.setCurrentContext(recordContext);
keyInKeyGroups.add(keyInKeyGroup);
state.update(keyInKeyGroup);
expectedValue.add(keyInKeyGroup);
recordContext.release();
}

// snapshot
snapshots.add(
runSnapshot(
backend.snapshot(
0,
0,
streamFactory,
CheckpointOptions.forCheckpointWithDefaultLocation()),
sharedStateRegistry));
} finally {
IOUtils.closeQuietly(backend);
backend.dispose();
}
}

// redistribute the stateHandle
List<KeyGroupRange> keyGroupRangesRestore = new ArrayList<>();
for (int i = 0; i < targetParallelism; ++i) {
keyGroupRangesRestore.add(
KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
maxParallelism, targetParallelism, i));
}
List<List<KeyedStateHandle>> keyGroupStatesAfterDistribute =
new ArrayList<>(targetParallelism);
for (int i = 0; i < targetParallelism; ++i) {
List<KeyedStateHandle> keyedStateHandles = new ArrayList<>();
StateAssignmentOperation.extractIntersectingState(
snapshots, keyGroupRangesRestore.get(i), keyedStateHandles);
keyGroupStatesAfterDistribute.add(keyedStateHandles);
}
// restore and verify
for (int i = 0; i < targetParallelism; ++i) {
AsyncKeyedStateBackend<Integer> backend =
restoreAsyncKeyedBackend(
IntSerializer.INSTANCE,
keyGroupRangesRestore.get(i),
keyGroupStatesAfterDistribute.get(i),
env);
aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
testExceptionHandler,
backend.createStateExecutor(),
maxParallelism,
aecBatchSize,
aecBufferTimeout,
aecMaxInFlightRecords,
null);
backend.setup(aec);

try {
KeyGroupRange range = keyGroupRangesRestore.get(i);
for (int j = range.getStartKeyGroup(); j <= range.getEndKeyGroup(); ++j) {
ValueState<Integer> state =
backend.createState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescriptors.get(j));
RecordContext recordContext =
aec.buildContext(keyInKeyGroups.get(j), keyInKeyGroups.get(j));
recordContext.retain();
aec.setCurrentContext(recordContext);
assertThat(state.value()).isEqualTo(expectedValue.get(j));
recordContext.release();
}
} finally {
IOUtils.closeQuietly(backend);
backend.dispose();
}
}
}

@TestTemplate
void testKeyGroupedInternalPriorityQueue() throws Exception {
testKeyGroupedInternalPriorityQueue(false);
Expand All @@ -358,7 +512,7 @@ void testKeyGroupedInternalPriorityQueueAddAll() throws Exception {
void testKeyGroupedInternalPriorityQueue(boolean addAll) throws Exception {
String fieldName = "key-grouped-priority-queue";
AsyncKeyedStateBackend<Integer> backend =
createAsyncKeyedBackend(IntSerializer.INSTANCE, 128, env);
createAsyncKeyedBackend(IntSerializer.INSTANCE, new KeyGroupRange(0, 127), env);
try {
KeyGroupedInternalPriorityQueue<TestType> priorityQueue =
backend.create(fieldName, new TestType.V1TestTypeSerializer());
Expand Down Expand Up @@ -414,7 +568,7 @@ void testValueStateWorkWithTtl() throws Exception {
TestAsyncFrameworkExceptionHandler testExceptionHandler =
new TestAsyncFrameworkExceptionHandler();
AsyncKeyedStateBackend<Long> backend =
createAsyncKeyedBackend(LongSerializer.INSTANCE, 128, env);
createAsyncKeyedBackend(LongSerializer.INSTANCE, new KeyGroupRange(0, 127), env);
AsyncExecutionController<Long> aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
Expand Down Expand Up @@ -468,6 +622,34 @@ void testValueStateWorkWithTtl() throws Exception {
}
}

/** Returns an Integer key in specified keyGroupRange. */
private int getKeyInKeyGroup(Random random, int maxParallelism, KeyGroupRange keyGroupRange) {
int keyInKG = random.nextInt();
int kg = KeyGroupRangeAssignment.assignToKeyGroup(keyInKG, maxParallelism);
while (!keyGroupRange.contains(kg)) {
keyInKG = random.nextInt();
kg = KeyGroupRangeAssignment.assignToKeyGroup(keyInKG, maxParallelism);
}
return keyInKG;
}

private static KeyedStateHandle runSnapshot(
RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunnableFuture,
SharedStateRegistry sharedStateRegistry)
throws Exception {

if (!snapshotRunnableFuture.isDone()) {
snapshotRunnableFuture.run();
}

SnapshotResult<KeyedStateHandle> snapshotResult = snapshotRunnableFuture.get();
KeyedStateHandle jobManagerOwnedSnapshot = snapshotResult.getJobManagerOwnedSnapshot();
if (jobManagerOwnedSnapshot != null) {
jobManagerOwnedSnapshot.registerSharedStates(sharedStateRegistry, 0L);
}
return jobManagerOwnedSnapshot;
}

static class TestAsyncFrameworkExceptionHandler
implements StateFutureImpl.AsyncFrameworkExceptionHandler {
String message = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,32 @@ public class ForStConfigurableOptions implements Serializable {
+ " and re-written to the same level as they were before. It makes sure a file goes through"
+ " compaction filters periodically. 0 means turning off periodic compaction.The default value is '30days'.");

public static final ConfigOption<Double> RESTORE_OVERLAP_FRACTION_THRESHOLD =
key("state.backend.forst.restore-overlap-fraction-threshold")
.doubleType()
.defaultValue(0.0)
.withDescription(
"The threshold of overlap fraction between the handle's key-group range and target key-group range. "
+ "When restore base DB, only the handle which overlap fraction greater than or equal to threshold "
+ "has a chance to be an initial handle. "
+ "The default value is 0.0, there is always a handle will be selected for initialization. ");

public static final ConfigOption<Boolean> USE_INGEST_DB_RESTORE_MODE =
key("state.backend.forst.use-ingest-db-restore-mode")
.booleanType()
.defaultValue(Boolean.FALSE)
.withDescription(
"A recovery mode that directly clips and ingests multiple DBs during state recovery if the keys"
+ " in the SST files does not exceed the declared key-group range.");

public static final ConfigOption<Boolean> USE_DELETE_FILES_IN_RANGE_DURING_RESCALING =
key("state.backend.forst.rescaling.use-delete-files-in-range")
.booleanType()
.defaultValue(Boolean.FALSE)
.withDescription(
"If true, during rescaling, the deleteFilesInRange API will be invoked "
+ "to clean up the useless files so that local disk space can be reclaimed more promptly.");

static final ConfigOption<?>[] CANDIDATE_CONFIGS =
new ConfigOption<?>[] {
// cache
Expand Down Expand Up @@ -361,6 +387,9 @@ public class ForStConfigurableOptions implements Serializable {
USE_BLOOM_FILTER,
BLOOM_FILTER_BITS_PER_KEY,
BLOOM_FILTER_BLOCK_BASED_MODE,
RESTORE_OVERLAP_FRACTION_THRESHOLD,
USE_INGEST_DB_RESTORE_MODE,
USE_DELETE_FILES_IN_RANGE_DURING_RESCALING
};

private static final Set<ConfigOption<?>> POSITIVE_INT_CONFIG_SET =
Expand Down
Loading