Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added relationships APIs to V3. Added these generic APIs to V3 swagger doc. #10939

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
import io.swagger.v3.oas.annotations.OpenAPIDefinition;
import io.swagger.v3.oas.annotations.info.Info;
import io.swagger.v3.oas.annotations.servers.Server;
import io.swagger.v3.oas.models.Components;
import io.swagger.v3.oas.models.OpenAPI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springdoc.core.models.GroupedOpenApi;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -38,8 +43,6 @@ public class SpringWebConfig implements WebMvcConfigurer {
private static final Set<String> V1_PACKAGES = Set.of("io.datahubproject.openapi.v1");
private static final Set<String> V2_PACKAGES = Set.of("io.datahubproject.openapi.v2");
private static final Set<String> V3_PACKAGES = Set.of("io.datahubproject.openapi.v3");
private static final Set<String> SCHEMA_REGISTRY_PACKAGES =
Set.of("io.datahubproject.openapi.schema.registry");

private static final Set<String> OPENLINEAGE_PACKAGES =
Set.of("io.datahubproject.openapi.openlineage");
Expand Down Expand Up @@ -74,14 +77,31 @@ public void addFormatters(FormatterRegistry registry) {
public GroupedOpenApi v3OpenApiGroup(final EntityRegistry entityRegistry) {
return GroupedOpenApi.builder()
.group("10-openapi-v3")
.displayName("DataHub Entities v3 (OpenAPI)")
.displayName("DataHub v3 (OpenAPI)")
.addOpenApiCustomizer(
openApi -> {
OpenAPI v3OpenApi = OpenAPIV3Generator.generateOpenApiSpec(entityRegistry);
openApi.setInfo(v3OpenApi.getInfo());
openApi.setTags(Collections.emptyList());
openApi.setPaths(v3OpenApi.getPaths());
openApi.setComponents(v3OpenApi.getComponents());
openApi.getPaths().putAll(v3OpenApi.getPaths());
// Merge components. Swagger does not provide append method to add components.
final Components components = new Components();
final Components oComponents = openApi.getComponents();
final Components v3Components = v3OpenApi.getComponents();
components
.callbacks(concat(oComponents::getCallbacks, v3Components::getCallbacks))
.examples(concat(oComponents::getExamples, v3Components::getExamples))
.extensions(concat(oComponents::getExtensions, v3Components::getExtensions))
.headers(concat(oComponents::getHeaders, v3Components::getHeaders))
.links(concat(oComponents::getLinks, v3Components::getLinks))
.parameters(concat(oComponents::getParameters, v3Components::getParameters))
.requestBodies(
concat(oComponents::getRequestBodies, v3Components::getRequestBodies))
.responses(concat(oComponents::getResponses, v3Components::getResponses))
.schemas(concat(oComponents::getSchemas, v3Components::getSchemas))
.securitySchemes(
concat(oComponents::getSecuritySchemes, v3Components::getSecuritySchemes));
openApi.setComponents(components);
})
.packagesToScan(V3_PACKAGES.toArray(String[]::new))
.build();
Expand Down Expand Up @@ -122,4 +142,14 @@ public GroupedOpenApi openlineageOpenApiGroup() {
.packagesToScan(OPENLINEAGE_PACKAGES.toArray(String[]::new))
.build();
}

/** Concatenates two maps. */
private <K, V> Map<K, V> concat(Supplier<Map<K, V>> a, Supplier<Map<K, V>> b) {
return a.get() == null
? b.get()
: b.get() == null
? a.get()
: Stream.concat(a.get().entrySet().stream(), b.get().entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package io.datahubproject.openapi.controller;

import static com.linkedin.metadata.authorization.ApiGroup.RELATIONSHIP;
import static com.linkedin.metadata.authorization.ApiOperation.READ;

import com.datahub.authentication.Authentication;
import com.datahub.authentication.AuthenticationContext;
import com.datahub.authorization.AuthUtil;
import com.datahub.authorization.AuthorizerChain;
import com.linkedin.common.urn.Urn;
import com.linkedin.common.urn.UrnUtils;
import com.linkedin.metadata.aspect.models.graph.Edge;
import com.linkedin.metadata.aspect.models.graph.RelatedEntities;
import com.linkedin.metadata.aspect.models.graph.RelatedEntitiesScrollResult;
import com.linkedin.metadata.graph.elastic.ElasticSearchGraphService;
import com.linkedin.metadata.models.registry.EntityRegistry;
import com.linkedin.metadata.query.filter.RelationshipDirection;
import com.linkedin.metadata.query.filter.RelationshipFilter;
import com.linkedin.metadata.search.utils.QueryUtils;
import io.datahubproject.openapi.exception.UnauthorizedException;
import io.datahubproject.openapi.models.GenericScrollResult;
import io.datahubproject.openapi.v2.models.GenericRelationship;
import io.swagger.v3.oas.annotations.Operation;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

public abstract class GenericRelationshipController {

@Autowired private EntityRegistry entityRegistry;
@Autowired private ElasticSearchGraphService graphService;
@Autowired private AuthorizerChain authorizationChain;

/**
* Returns relationship edges by type
*
* @param relationshipType the relationship type
* @param count number of results
* @param scrollId scrolling id
* @return list of relation edges
*/
@GetMapping(value = "/{relationshipType}", produces = MediaType.APPLICATION_JSON_VALUE)
@Operation(summary = "Scroll relationships of the given type.")
public ResponseEntity<GenericScrollResult<GenericRelationship>> getRelationshipsByType(
@PathVariable("relationshipType") String relationshipType,
@RequestParam(value = "count", defaultValue = "10") Integer count,
@RequestParam(value = "scrollId", required = false) String scrollId) {

Authentication authentication = AuthenticationContext.getAuthentication();
if (!AuthUtil.isAPIAuthorized(authentication, authorizationChain, RELATIONSHIP, READ)) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

RelatedEntitiesScrollResult result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
List.of(relationshipType),
new RelationshipFilter().setDirection(RelationshipDirection.UNDIRECTED),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);

if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
result.getEntities().stream()
.flatMap(
edge ->
Stream.of(
UrnUtils.getUrn(edge.getSourceUrn()),
UrnUtils.getUrn(edge.getDestinationUrn())))
.collect(Collectors.toSet()))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

return ResponseEntity.ok(
GenericScrollResult.<GenericRelationship>builder()
.results(toGenericRelationships(result.getEntities()))
.scrollId(result.getScrollId())
.build());
}

/**
* Returns edges for a given urn
*
* @param relationshipTypes types of edges
* @param direction direction of the edges
* @param count number of results
* @param scrollId scroll id
* @return urn edges
*/
@GetMapping(value = "/{entityName}/{entityUrn}", produces = MediaType.APPLICATION_JSON_VALUE)
@Operation(summary = "Scroll relationships from a given entity.")
public ResponseEntity<GenericScrollResult<GenericRelationship>> getRelationshipsByEntity(
@PathVariable("entityName") String entityName,
@PathVariable("entityUrn") String entityUrn,
@RequestParam(value = "relationshipType[]", required = false, defaultValue = "*")
String[] relationshipTypes,
@RequestParam(value = "direction", defaultValue = "OUTGOING") String direction,
@RequestParam(value = "count", defaultValue = "10") Integer count,
@RequestParam(value = "scrollId", required = false) String scrollId) {

final RelatedEntitiesScrollResult result;

Authentication authentication = AuthenticationContext.getAuthentication();
if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
List.of(UrnUtils.getUrn(entityUrn)))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

switch (RelationshipDirection.valueOf(direction.toUpperCase())) {
case INCOMING -> result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
relationshipTypes.length > 0 && !relationshipTypes[0].equals("*")
? Arrays.stream(relationshipTypes).toList()
: List.of(),
new RelationshipFilter()
.setDirection(RelationshipDirection.UNDIRECTED)
.setOr(QueryUtils.newFilter("destination.urn", entityUrn).getOr()),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);
case OUTGOING -> result =
graphService.scrollRelatedEntities(
null,
null,
null,
null,
relationshipTypes.length > 0 && !relationshipTypes[0].equals("*")
? Arrays.stream(relationshipTypes).toList()
: List.of(),
new RelationshipFilter()
.setDirection(RelationshipDirection.UNDIRECTED)
.setOr(QueryUtils.newFilter("source.urn", entityUrn).getOr()),
Edge.EDGE_SORT_CRITERION,
scrollId,
count,
null,
null);
default -> throw new IllegalArgumentException("Direction must be INCOMING or OUTGOING");
}

if (!AuthUtil.isAPIAuthorizedUrns(
authentication,
authorizationChain,
RELATIONSHIP,
READ,
result.getEntities().stream()
.flatMap(
edge ->
Stream.of(
UrnUtils.getUrn(edge.getSourceUrn()),
UrnUtils.getUrn(edge.getDestinationUrn())))
.collect(Collectors.toSet()))) {
throw new UnauthorizedException(
authentication.getActor().toUrnStr()
+ " is unauthorized to "
+ READ
+ " "
+ RELATIONSHIP);
}

return ResponseEntity.ok(
GenericScrollResult.<GenericRelationship>builder()
.results(toGenericRelationships(result.getEntities()))
.scrollId(result.getScrollId())
.build());
}

private List<GenericRelationship> toGenericRelationships(List<RelatedEntities> relatedEntities) {
return relatedEntities.stream()
.map(
result -> {
Urn source = UrnUtils.getUrn(result.getSourceUrn());
Urn dest = UrnUtils.getUrn(result.getDestinationUrn());
return GenericRelationship.builder()
.relationshipType(result.getRelationshipType())
.source(GenericRelationship.GenericNode.fromUrn(source))
.destination(GenericRelationship.GenericNode.fromUrn(dest))
.build();
})
.collect(Collectors.toList());
}
}
Loading
Loading