Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[resources] Introduce planner side validation of resources #387

Merged
merged 3 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import ai.langstream.api.model.Connection;
import ai.langstream.api.model.Module;
import ai.langstream.api.model.Pipeline;
import ai.langstream.api.model.Resource;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -84,5 +86,10 @@ default AgentNodeMetadata computeAgentMetadata(
return null;
}

default Map<String, Object> getResourceImplementation(
Resource resource, PluginsRegistry pluginsRegistry) {
return new HashMap<>(resource.configuration());
}

default void close() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,29 @@ public AssetNodeProvider lookupAssetImplementation(
+ clusterRuntime.getClusterType()));
return assetRuntimeProviderProvider.get();
}

public ResourceNodeProvider lookupResourceImplementation(
String type, ComputeClusterRuntime clusterRuntime) {
log.info(
"Looking for an implementation of resource type {} on {}",
type,
clusterRuntime.getClusterType());
ServiceLoader<ResourceNodeProvider> loader = ServiceLoader.load(ResourceNodeProvider.class);
ServiceLoader.Provider<ResourceNodeProvider> runtimeProviderProvider =
loader.stream()
.filter(
p -> {
ResourceNodeProvider nodeProvider = p.get();
return nodeProvider.supports(type, clusterRuntime);
})
.findFirst()
.orElseThrow(
() ->
new RuntimeException(
"No ResourceNodeProvider found for resource type "
+ type
+ " for cluster type "
+ clusterRuntime.getClusterType()));
return runtimeProviderProvider.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.api.runtime;

import ai.langstream.api.model.Module;
import ai.langstream.api.model.Resource;
import java.util.Map;

public interface ResourceNodeProvider {

/**
* Create an Implementation of a Resource.
*
* @param module the module
* @param executionPlan the physical application instance
* @param clusterRuntime the cluster runtime
* @param pluginsRegistry the plugins registry
* @return the Agent
*/
Map<String, Object> createImplementation(
Resource resource,
Module module,
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry);

/**
* Returns the ability of a Resource to be deployed on the give runtimes.
*
* @param type the type of implementation
* @param clusterRuntime the compute cluster runtime
* @return true if this provider can create the implementation
*/
boolean supports(String type, ComputeClusterRuntime clusterRuntime);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import ai.langstream.api.runtime.ComponentType;
import ai.langstream.api.runtime.ComputeClusterRuntime;
import ai.langstream.api.runtime.ExecutionPlan;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.impl.common.AbstractAgentProvider;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -390,7 +391,9 @@ protected void generateSteps(
Map<String, Object> originalConfiguration,
Map<String, Object> configuration,
Application application,
AgentConfiguration agentConfiguration) {
AgentConfiguration agentConfiguration,
ComputeClusterRuntime computeClusterRuntime,
PluginsRegistry pluginsRegistry) {
List<Map<String, Object>> steps = new ArrayList<>();
configuration.put("steps", steps);
Map<String, Object> step = new HashMap<>();
Expand All @@ -403,7 +406,12 @@ protected void generateSteps(

DataSourceConfigurationGenerator dataSourceConfigurationInjector =
(resourceId) ->
generateDataSourceConfiguration(resourceId, application, configuration);
generateDataSourceConfiguration(
resourceId,
application,
configuration,
computeClusterRuntime,
pluginsRegistry);

TopicConfigurationGenerator topicConfigurationGenerator =
(topicName) -> {
Expand All @@ -427,11 +435,15 @@ interface DataSourceConfigurationGenerator {
}

private void generateAIProvidersConfiguration(
Application applicationInstance, Map<String, Object> configuration) {
Application applicationInstance,
Map<String, Object> configuration,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
// let the user force the provider or detect it automatically
String service = (String) configuration.remove("service");
for (Resource resource : applicationInstance.getResources().values()) {
HashMap<String, Object> configurationCopy = new HashMap<>(resource.configuration());
Map<String, Object> configurationCopy =
clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
switch (resource.type()) {
case "vertex-configuration":
if (service == null || service.equals("vertex")) {
Expand All @@ -456,7 +468,11 @@ private void generateAIProvidersConfiguration(
}

private void generateDataSourceConfiguration(
String resourceId, Application applicationInstance, Map<String, Object> configuration) {
String resourceId,
Application applicationInstance,
Map<String, Object> configuration,
ComputeClusterRuntime computeClusterRuntime,
PluginsRegistry pluginsRegistry) {
Resource resource = applicationInstance.getResources().get(resourceId);
log.info("Generating datasource configuration for {}", resourceId);
if (resource != null) {
Expand All @@ -467,7 +483,9 @@ private void generateDataSourceConfiguration(
if (configuration.containsKey("datasource")) {
throw new IllegalArgumentException("Only one datasource is supported");
}
configuration.put("datasource", new HashMap<>(resource.configuration()));
Map<String, Object> resourceImplementation =
computeClusterRuntime.getResourceImplementation(resource, pluginsRegistry);
configuration.put("datasource", resourceImplementation);
} else {
throw new IllegalArgumentException("Resource " + resourceId + " not found");
}
Expand All @@ -479,20 +497,29 @@ protected Map<String, Object> computeAgentConfiguration(
Module module,
Pipeline pipeline,
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime) {
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
Map<String, Object> originalConfiguration =
super.computeAgentConfiguration(
agentConfiguration, module, pipeline, executionPlan, clusterRuntime);
agentConfiguration,
module,
pipeline,
executionPlan,
clusterRuntime,
pluginsRegistry);
Map<String, Object> configuration = new HashMap<>();

generateAIProvidersConfiguration(executionPlan.getApplication(), configuration);
generateAIProvidersConfiguration(
executionPlan.getApplication(), configuration, clusterRuntime, pluginsRegistry);

generateSteps(
module,
originalConfiguration,
configuration,
executionPlan.getApplication(),
agentConfiguration);
agentConfiguration,
clusterRuntime,
pluginsRegistry);
return configuration;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ protected Map<String, Object> computeAgentConfiguration(
Module module,
Pipeline pipeline,
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime) {
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
return new HashMap<>(agentConfiguration.getConfiguration());
}

Expand All @@ -141,7 +142,12 @@ public AgentNode createImplementation(
ComponentType componentType = getComponentType(agentConfiguration);
Map<String, Object> configuration =
computeAgentConfiguration(
agentConfiguration, module, pipeline, executionPlan, clusterRuntime);
agentConfiguration,
module,
pipeline,
executionPlan,
clusterRuntime,
pluginsRegistry);
// we create the output connection first to make sure that the topic is created
ConnectionImplementation output =
computeOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ public final AssetNode createImplementation(
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
Map<String, Object> asset = planAsset(assetDefinition, executionPlan.getApplication());
Map<String, Object> asset =
planAsset(
assetDefinition,
executionPlan.getApplication(),
clusterRuntime,
pluginsRegistry);
validateAsset(assetDefinition, asset);
return new AssetNode(asset);
}
Expand All @@ -54,7 +59,10 @@ protected abstract void validateAsset(
AssetDefinition assetDefinition, Map<String, Object> asset);

private Map<String, Object> planAsset(
AssetDefinition assetDefinition, Application application) {
AssetDefinition assetDefinition,
Application application,
ComputeClusterRuntime computeClusterRuntime,
PluginsRegistry pluginsRegistry) {

if (!supportedType.contains(assetDefinition.getAssetType())) {
throw new IllegalStateException();
Expand All @@ -81,7 +89,10 @@ private Map<String, Object> planAsset(
key);
Resource resource = resources.get(resourceId);
if (resource != null) {
value = Map.of("configuration", resource.configuration());
Map<String, Object> resourceImplementation =
computeClusterRuntime.getResourceImplementation(
resource, pluginsRegistry);
value = Map.of("configuration", resourceImplementation);
} else {
throw new IllegalArgumentException(
"Resource with name="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.langstream.api.model.Connection;
import ai.langstream.api.model.Module;
import ai.langstream.api.model.Pipeline;
import ai.langstream.api.model.Resource;
import ai.langstream.api.model.TopicDefinition;
import ai.langstream.api.runtime.AgentNode;
import ai.langstream.api.runtime.AgentNodeProvider;
Expand All @@ -32,8 +33,10 @@
import ai.langstream.api.runtime.ExecutionPlan;
import ai.langstream.api.runtime.ExecutionPlanOptimiser;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.api.runtime.ResourceNodeProvider;
import ai.langstream.api.runtime.StreamingClusterRuntime;
import ai.langstream.api.runtime.Topic;
import java.util.HashMap;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -143,6 +146,15 @@ protected void detectAgents(
}
}

@Override
public Map<String, Object> getResourceImplementation(
Resource resource, PluginsRegistry pluginsRegistry) {
ResourceNodeProvider nodeProvider =
pluginsRegistry.lookupResourceImplementation(resource.type(), this);
// TODO: validate resource
return new HashMap<>(resource.configuration());
}

protected AgentNode buildAgent(
Module module,
Pipeline pipeline,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.impl.resources;

import ai.langstream.api.model.Module;
import ai.langstream.api.model.Resource;
import ai.langstream.api.runtime.ComputeClusterRuntime;
import ai.langstream.api.runtime.ExecutionPlan;
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.api.runtime.ResourceNodeProvider;
import java.util.Map;
import java.util.Set;

public class AIProvidersResourceProvider implements ResourceNodeProvider {

private static final Set<String> SUPPORTED_TYPES =
Set.of("open-ai-configuration", "hugging-face-configuration", "vertex-configuration");

@Override
public Map<String, Object> createImplementation(
Resource resource,
Module module,
ExecutionPlan executionPlan,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
Map<String, Object> configuration = resource.configuration();
return resource.configuration();
}

@Override
public boolean supports(String type, ComputeClusterRuntime clusterRuntime) {
return SUPPORTED_TYPES.contains(type);
}
}
Loading