diff --git a/.changesets/feat_tninesling_cost_directives.md b/.changesets/feat_tninesling_cost_directives.md new file mode 100644 index 0000000000..e7994edc84 --- /dev/null +++ b/.changesets/feat_tninesling_cost_directives.md @@ -0,0 +1,28 @@ +### Account for demand control directives when scoring operations ([PR #5777](https://github.com/apollographql/router/pull/5777)) + +When scoring operations in the demand control plugin, utilize applications of `@cost` and `@listSize` from the supergraph schema to make better cost estimates. + +For expensive resolvers, the `@cost` directive can override the default weights in the cost calculation. + +```graphql +type Product { + id: ID! + name: String + expensiveField: Int @cost(weight: 20) +} +``` + +Additionally, if a list field's length differs significantly from the globally-configured list size, the `@listSize` directive can provide a tighter size estimate. + +```graphql +type Magazine { + # This is assumed to always return 5 items + headlines: [Article] @listSize(assumedSize: 5) + + # This is estimated to return as many items as are requested by the parameter named "first" + getPage(first: Int!, after: ID!): [Article] + @listSize(slicingArguments: ["first"]) +} +``` + +By [@tninesling](https://github.com/tninesling) in https://github.com/apollographql/router/pull/5777 diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/directives.rs b/apollo-router/src/plugins/demand_control/cost_calculator/directives.rs index b3f3afe372..c4dcc36b00 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/directives.rs +++ b/apollo-router/src/plugins/demand_control/cost_calculator/directives.rs @@ -1,13 +1,112 @@ +use ahash::HashMap; +use ahash::HashMapExt; +use ahash::HashSet; +use apollo_compiler::ast::DirectiveList; +use apollo_compiler::ast::FieldDefinition; +use apollo_compiler::ast::InputValueDefinition; use apollo_compiler::ast::NamedType; use apollo_compiler::executable::Field; use apollo_compiler::executable::SelectionSet; +use apollo_compiler::name; use apollo_compiler::parser::Parser; +use apollo_compiler::schema::ExtendedType; use apollo_compiler::validation::Valid; +use apollo_compiler::Name; use apollo_compiler::Schema; +use apollo_federation::link::spec::APOLLO_SPEC_DOMAIN; +use apollo_federation::link::Link; use tower::BoxError; use super::DemandControlError; +const COST_DIRECTIVE_NAME: Name = name!("cost"); +const COST_DIRECTIVE_DEFAULT_NAME: Name = name!("federation__cost"); +const COST_DIRECTIVE_WEIGHT_ARGUMENT_NAME: Name = name!("weight"); + +const LIST_SIZE_DIRECTIVE_NAME: Name = name!("listSize"); +const LIST_SIZE_DIRECTIVE_DEFAULT_NAME: Name = name!("federation__listSize"); +const LIST_SIZE_DIRECTIVE_ASSUMED_SIZE_ARGUMENT_NAME: Name = name!("assumedSize"); +const LIST_SIZE_DIRECTIVE_SLICING_ARGUMENTS_ARGUMENT_NAME: Name = name!("slicingArguments"); +const LIST_SIZE_DIRECTIVE_SIZED_FIELDS_ARGUMENT_NAME: Name = name!("sizedFields"); +const LIST_SIZE_DIRECTIVE_REQUIRE_ONE_SLICING_ARGUMENT_ARGUMENT_NAME: Name = + name!("requireOneSlicingArgument"); + +pub(in crate::plugins::demand_control) fn get_apollo_directive_names( + schema: &Schema, +) -> HashMap { + let mut hm: HashMap = HashMap::new(); + for directive in &schema.schema_definition.directives { + if directive.name.as_str() == "link" { + if let Ok(link) = Link::from_directive_application(directive) { + if link.url.identity.domain != APOLLO_SPEC_DOMAIN { + continue; + } + for import in link.imports { + hm.insert(import.element.clone(), import.imported_name().clone()); + } + } + } + } + hm +} + +pub(in crate::plugins::demand_control) struct CostDirective { + weight: i32, +} + +impl CostDirective { + pub(in crate::plugins::demand_control) fn weight(&self) -> f64 { + self.weight as f64 + } + + pub(in crate::plugins::demand_control) fn from_argument( + directive_name_map: &HashMap, + argument: &InputValueDefinition, + ) -> Option { + Self::from_directives(directive_name_map, &argument.directives) + } + + pub(in crate::plugins::demand_control) fn from_field( + directive_name_map: &HashMap, + field: &FieldDefinition, + ) -> Option { + Self::from_directives(directive_name_map, &field.directives) + } + + pub(in crate::plugins::demand_control) fn from_type( + directive_name_map: &HashMap, + ty: &ExtendedType, + ) -> Option { + Self::from_schema_directives(directive_name_map, ty.directives()) + } + + fn from_directives( + directive_name_map: &HashMap, + directives: &DirectiveList, + ) -> Option { + directive_name_map + .get(&COST_DIRECTIVE_NAME) + .and_then(|name| directives.get(name)) + .or(directives.get(&COST_DIRECTIVE_DEFAULT_NAME)) + .and_then(|cost| cost.argument_by_name(&COST_DIRECTIVE_WEIGHT_ARGUMENT_NAME)) + .and_then(|weight| weight.to_i32()) + .map(|weight| Self { weight }) + } + + pub(in crate::plugins::demand_control) fn from_schema_directives( + directive_name_map: &HashMap, + directives: &apollo_compiler::schema::DirectiveList, + ) -> Option { + directive_name_map + .get(&COST_DIRECTIVE_NAME) + .and_then(|name| directives.get(name)) + .or(directives.get(&COST_DIRECTIVE_DEFAULT_NAME)) + .and_then(|cost| cost.argument_by_name(&COST_DIRECTIVE_WEIGHT_ARGUMENT_NAME)) + .and_then(|weight| weight.to_i32()) + .map(|weight| Self { weight }) + } +} + pub(in crate::plugins::demand_control) struct IncludeDirective { pub(in crate::plugins::demand_control) is_included: bool, } @@ -27,31 +126,142 @@ impl IncludeDirective { } } +pub(in crate::plugins::demand_control) struct ListSizeDirective<'schema> { + pub(in crate::plugins::demand_control) expected_size: Option, + pub(in crate::plugins::demand_control) sized_fields: Option>, +} + +impl<'schema> ListSizeDirective<'schema> { + pub(in crate::plugins::demand_control) fn size_of(&self, field: &Field) -> Option { + if self + .sized_fields + .as_ref() + .is_some_and(|sf| sf.contains(field.name.as_str())) + { + self.expected_size + } else { + None + } + } +} + +/// The `@listSize` directive from a field definition, which can be converted to +/// `ListSizeDirective` with a concrete field from a request. +pub(in crate::plugins::demand_control) struct DefinitionListSizeDirective { + assumed_size: Option, + slicing_argument_names: Option>, + sized_fields: Option>, + require_one_slicing_argument: bool, +} + +impl DefinitionListSizeDirective { + pub(in crate::plugins::demand_control) fn from_field_definition( + directive_name_map: &HashMap, + definition: &FieldDefinition, + ) -> Result, DemandControlError> { + let directive = directive_name_map + .get(&LIST_SIZE_DIRECTIVE_NAME) + .and_then(|name| definition.directives.get(name)) + .or(definition.directives.get(&LIST_SIZE_DIRECTIVE_DEFAULT_NAME)); + if let Some(directive) = directive { + let assumed_size = directive + .argument_by_name(&LIST_SIZE_DIRECTIVE_ASSUMED_SIZE_ARGUMENT_NAME) + .and_then(|arg| arg.to_i32()); + let slicing_argument_names = directive + .argument_by_name(&LIST_SIZE_DIRECTIVE_SLICING_ARGUMENTS_ARGUMENT_NAME) + .and_then(|arg| arg.as_list()) + .map(|arg_list| { + arg_list + .iter() + .flat_map(|arg| arg.as_str()) + .map(String::from) + .collect() + }); + let sized_fields = directive + .argument_by_name(&LIST_SIZE_DIRECTIVE_SIZED_FIELDS_ARGUMENT_NAME) + .and_then(|arg| arg.as_list()) + .map(|arg_list| { + arg_list + .iter() + .flat_map(|arg| arg.as_str()) + .map(String::from) + .collect() + }); + let require_one_slicing_argument = directive + .argument_by_name(&LIST_SIZE_DIRECTIVE_REQUIRE_ONE_SLICING_ARGUMENT_ARGUMENT_NAME) + .and_then(|arg| arg.to_bool()) + .unwrap_or(true); + + Ok(Some(Self { + assumed_size, + slicing_argument_names, + sized_fields, + require_one_slicing_argument, + })) + } else { + Ok(None) + } + } + + pub(in crate::plugins::demand_control) fn with_field( + &self, + field: &Field, + ) -> Result { + let mut slicing_arguments: HashMap<&str, i32> = HashMap::new(); + if let Some(slicing_argument_names) = self.slicing_argument_names.as_ref() { + // First, collect the default values for each argument + for argument in &field.definition.arguments { + if slicing_argument_names.contains(argument.name.as_str()) { + if let Some(numeric_value) = + argument.default_value.as_ref().and_then(|v| v.to_i32()) + { + slicing_arguments.insert(&argument.name, numeric_value); + } + } + } + // Then, overwrite any default values with the actual values passed in the query + for argument in &field.arguments { + if slicing_argument_names.contains(argument.name.as_str()) { + if let Some(numeric_value) = argument.value.to_i32() { + slicing_arguments.insert(&argument.name, numeric_value); + } + } + } + + if self.require_one_slicing_argument && slicing_arguments.len() != 1 { + return Err(DemandControlError::QueryParseFailure(format!( + "Exactly one slicing argument is required, but found {}", + slicing_arguments.len() + ))); + } + } + + let expected_size = slicing_arguments + .values() + .max() + .cloned() + .or(self.assumed_size); + + Ok(ListSizeDirective { + expected_size, + sized_fields: self + .sized_fields + .as_ref() + .map(|set| set.iter().map(|s| s.as_str()).collect()), + }) + } +} + pub(in crate::plugins::demand_control) struct RequiresDirective { pub(in crate::plugins::demand_control) fields: SelectionSet, } impl RequiresDirective { - pub(in crate::plugins::demand_control) fn from_field( - field: &Field, + pub(in crate::plugins::demand_control) fn from_field_definition( + definition: &FieldDefinition, parent_type_name: &NamedType, schema: &Valid, ) -> Result, DemandControlError> { - // When a user marks a subgraph schema field with `@requires`, the composition process - // replaces `@requires(field: "")` with `@join__field(requires: "")`. - // - // Note we cannot use `field.definition` in this case: The operation executes against the - // API schema, so its definition pointers point into the API schema. To find the - // `@join__field()` directive, we must instead look up the field on the type with the same - // name in the supergraph. - let definition = schema - .type_field(parent_type_name, &field.name) - .map_err(|_err| { - DemandControlError::QueryParseFailure(format!( - "Could not find the API schema type {}.{} in the supergraph. This looks like a bug", - parent_type_name, &field.name - )) - })?; let requires_arg = definition .directives .get("join__field") diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query.graphql index 86a01356e7..c8494f9697 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query.graphql +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query.graphql @@ -1,3 +1,5 @@ query BasicInputObjectQuery { - getScalarByObject(args: { inner: { id: 1 } }) + getScalarByObject( + args: { inner: { id: 1 }, listOfInner: [{ id: 2 }, { id: 3 }] } + ) } diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query_2.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query_2.graphql new file mode 100644 index 0000000000..26a1a06623 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_query_2.graphql @@ -0,0 +1,8 @@ +query BasicInputObjectQuery2 { + getObjectsByObject( + args: { inner: { id: 1 }, listOfInner: [{ id: 2 }, { id: 3 }] } + ) { + field1 + field2 + } +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_response.json b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_response.json new file mode 100644 index 0000000000..092377bf7f --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_input_object_response.json @@ -0,0 +1,9 @@ +{ + "data": { + "getObjectsByObject": [ + { "field1": 1, "field2": "one" }, + { "field1": 2, "field2": "two" }, + { "field1": 3, "field2": "three" } + ] + } +} \ No newline at end of file diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_schema.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_schema.graphql index d613012b0d..17f3046414 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_schema.graphql +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/basic_schema.graphql @@ -7,6 +7,7 @@ type Query { someUnion: UnionOfObjectTypes someObjects: [FirstObjectType] intList: [Int] + getObjectsByObject(args: OuterInput): [SecondObjectType] } type Mutation { @@ -35,4 +36,6 @@ input InnerInput { input OuterInput { inner: InnerInput + inner2: InnerInput + listOfInner: [InnerInput!] } diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query.graphql new file mode 100644 index 0000000000..751c8a005e --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query.graphql @@ -0,0 +1,20 @@ +fragment Items on SizedField { + items { + id + } +} + +{ + fieldWithCost + argWithCost(arg: 3) + enumWithCost + inputWithCost(someInput: { somethingWithCost: 10 }) + scalarWithCost + objectWithCost { + id + } + fieldWithListSize + fieldWithDynamicListSize(first: 5) { + ...Items + } +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query_with_default_slicing_argument.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query_with_default_slicing_argument.graphql new file mode 100644 index 0000000000..fb50e08fef --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_query_with_default_slicing_argument.graphql @@ -0,0 +1,20 @@ +fragment Items on SizedField { + items { + id + } +} + +{ + fieldWithCost + argWithCost(arg: 3) + enumWithCost + inputWithCost(someInput: { somethingWithCost: 10 }) + scalarWithCost + objectWithCost { + id + } + fieldWithListSize + fieldWithDynamicListSize { + ...Items + } +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_response.json b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_response.json new file mode 100644 index 0000000000..664a2684e6 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_response.json @@ -0,0 +1,24 @@ +{ + "data": { + "fieldWithCost": 1, + "argWithCost": 2, + "enumWithCost": "A", + "inputWithCost": 3, + "scalarWithCost": 4, + "objectWithCost": { + "id": 5 + }, + "fieldWithListSize": [ + "first", + "second", + "third" + ], + "fieldWithDynamicListSize": { + "items": [ + { "id": 6 }, + { "id": 7 }, + { "id": 8 } + ] + } + } +} \ No newline at end of file diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema.graphql new file mode 100644 index 0000000000..d966512be1 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema.graphql @@ -0,0 +1,154 @@ +schema + @link(url: "https://specs.apollo.dev/link/v1.0") + @link(url: "https://specs.apollo.dev/join/v0.5", for: EXECUTION) + @link( + url: "https://specs.apollo.dev/cost/v0.1" + import: ["@cost", "@listSize"] + ) { + query: Query +} + +directive @cost( + weight: Int! +) on ARGUMENT_DEFINITION | ENUM | FIELD_DEFINITION | INPUT_FIELD_DEFINITION | OBJECT | SCALAR + +directive @cost__listSize( + assumedSize: Int + slicingArguments: [String!] + sizedFields: [String!] + requireOneSlicingArgument: Boolean = true +) on FIELD_DEFINITION + +directive @join__directive( + graphs: [join__Graph!] + name: String! + args: join__DirectiveArguments +) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION + +directive @join__enumValue(graph: join__Graph!) repeatable on ENUM_VALUE + +directive @join__field( + graph: join__Graph + requires: join__FieldSet + provides: join__FieldSet + type: String + external: Boolean + override: String + usedOverridden: Boolean + overrideLabel: String + contextArguments: [join__ContextArgument!] +) repeatable on FIELD_DEFINITION | INPUT_FIELD_DEFINITION + +directive @join__graph(name: String!, url: String!) on ENUM_VALUE + +directive @join__implements( + graph: join__Graph! + interface: String! +) repeatable on OBJECT | INTERFACE + +directive @join__type( + graph: join__Graph! + key: join__FieldSet + extension: Boolean! = false + resolvable: Boolean! = true + isInterfaceObject: Boolean! = false +) repeatable on OBJECT | INTERFACE | UNION | ENUM | INPUT_OBJECT | SCALAR + +directive @join__unionMember( + graph: join__Graph! + member: String! +) repeatable on UNION + +directive @link( + url: String + as: String + for: link__Purpose + import: [link__Import] +) repeatable on SCHEMA + +directive @listSize( + assumedSize: Int + slicingArguments: [String!] + sizedFields: [String!] + requireOneSlicingArgument: Boolean = true +) on FIELD_DEFINITION + +type A @join__type(graph: SUBGRAPHWITHLISTSIZE) { + id: ID +} + +enum AorB @join__type(graph: SUBGRAPHWITHCOST) @cost(weight: 15) { + A @join__enumValue(graph: SUBGRAPHWITHCOST) + B @join__enumValue(graph: SUBGRAPHWITHCOST) +} + +scalar ExpensiveInt @join__type(graph: SUBGRAPHWITHCOST) @cost(weight: 30) + +type ExpensiveObject @join__type(graph: SUBGRAPHWITHCOST) @cost(weight: 40) { + id: ID +} + +input InputTypeWithCost @join__type(graph: SUBGRAPHWITHCOST) { + somethingWithCost: Int @cost(weight: 20) +} + +input join__ContextArgument { + name: String! + type: String! + context: String! + selection: join__FieldValue! +} + +scalar join__DirectiveArguments + +scalar join__FieldSet + +scalar join__FieldValue + +enum join__Graph { + SUBGRAPHWITHCOST + @join__graph(name: "subgraphWithCost", url: "http://localhost:4001") + SUBGRAPHWITHLISTSIZE + @join__graph(name: "subgraphWithListSize", url: "http://localhost:4002") +} + +scalar link__Import + +enum link__Purpose { + """ + `SECURITY` features provide metadata necessary to securely resolve fields. + """ + SECURITY + + """ + `EXECUTION` features provide metadata necessary for operation execution. + """ + EXECUTION +} + +type Query + @join__type(graph: SUBGRAPHWITHCOST) + @join__type(graph: SUBGRAPHWITHLISTSIZE) { + fieldWithCost: Int @join__field(graph: SUBGRAPHWITHCOST) @cost(weight: 5) + argWithCost(arg: Int @cost(weight: 10)): Int + @join__field(graph: SUBGRAPHWITHCOST) + enumWithCost: AorB @join__field(graph: SUBGRAPHWITHCOST) + inputWithCost(someInput: InputTypeWithCost): Int + @join__field(graph: SUBGRAPHWITHCOST) + scalarWithCost: ExpensiveInt @join__field(graph: SUBGRAPHWITHCOST) + objectWithCost: ExpensiveObject @join__field(graph: SUBGRAPHWITHCOST) + fieldWithListSize: [String!] + @join__field(graph: SUBGRAPHWITHLISTSIZE) + @listSize(assumedSize: 2000, requireOneSlicingArgument: false) + fieldWithDynamicListSize(first: Int = 10): SizedField + @join__field(graph: SUBGRAPHWITHLISTSIZE) + @listSize( + slicingArguments: ["first"] + sizedFields: ["items"] + requireOneSlicingArgument: true + ) +} + +type SizedField @join__type(graph: SUBGRAPHWITHLISTSIZE) { + items: [A] +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema_with_renamed_directives.graphql b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema_with_renamed_directives.graphql new file mode 100644 index 0000000000..1d1f17263d --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/fixtures/custom_cost_schema_with_renamed_directives.graphql @@ -0,0 +1,163 @@ +schema + @link(url: "https://specs.apollo.dev/link/v1.0") + @link(url: "https://specs.apollo.dev/join/v0.5", for: EXECUTION) + @link( + url: "https://specs.apollo.dev/cost/v0.1" + import: [ + { name: "@cost", as: "@renamedCost" } + { name: "@listSize", as: "@renamedListSize" } + ] + ) { + query: Query +} + +directive @cost__listSize( + assumedSize: Int + slicingArguments: [String!] + sizedFields: [String!] + requireOneSlicingArgument: Boolean = true +) on FIELD_DEFINITION + +directive @join__directive( + graphs: [join__Graph!] + name: String! + args: join__DirectiveArguments +) repeatable on SCHEMA | OBJECT | INTERFACE | FIELD_DEFINITION + +directive @join__enumValue(graph: join__Graph!) repeatable on ENUM_VALUE + +directive @join__field( + graph: join__Graph + requires: join__FieldSet + provides: join__FieldSet + type: String + external: Boolean + override: String + usedOverridden: Boolean + overrideLabel: String + contextArguments: [join__ContextArgument!] +) repeatable on FIELD_DEFINITION | INPUT_FIELD_DEFINITION + +directive @join__graph(name: String!, url: String!) on ENUM_VALUE + +directive @join__implements( + graph: join__Graph! + interface: String! +) repeatable on OBJECT | INTERFACE + +directive @join__type( + graph: join__Graph! + key: join__FieldSet + extension: Boolean! = false + resolvable: Boolean! = true + isInterfaceObject: Boolean! = false +) repeatable on OBJECT | INTERFACE | UNION | ENUM | INPUT_OBJECT | SCALAR + +directive @join__unionMember( + graph: join__Graph! + member: String! +) repeatable on UNION + +directive @link( + url: String + as: String + for: link__Purpose + import: [link__Import] +) repeatable on SCHEMA + +directive @renamedCost( + weight: Int! +) on ARGUMENT_DEFINITION | ENUM | FIELD_DEFINITION | INPUT_FIELD_DEFINITION | OBJECT | SCALAR + +directive @renamedListSize( + assumedSize: Int + slicingArguments: [String!] + sizedFields: [String!] + requireOneSlicingArgument: Boolean = true +) on FIELD_DEFINITION + +type A @join__type(graph: SUBGRAPHWITHLISTSIZE) { + id: ID +} + +enum AorB @join__type(graph: SUBGRAPHWITHCOST) @renamedCost(weight: 15) { + A @join__enumValue(graph: SUBGRAPHWITHCOST) + B @join__enumValue(graph: SUBGRAPHWITHCOST) +} + +scalar ExpensiveInt + @join__type(graph: SUBGRAPHWITHCOST) + @renamedCost(weight: 30) + +type ExpensiveObject + @join__type(graph: SUBGRAPHWITHCOST) + @renamedCost(weight: 40) { + id: ID +} + +input InputTypeWithCost @join__type(graph: SUBGRAPHWITHCOST) { + somethingWithCost: Int @renamedCost(weight: 20) +} + +input join__ContextArgument { + name: String! + type: String! + context: String! + selection: join__FieldValue! +} + +scalar join__DirectiveArguments + +scalar join__FieldSet + +scalar join__FieldValue + +enum join__Graph { + SUBGRAPHWITHCOST + @join__graph(name: "subgraphWithCost", url: "http://localhost:4001") + SUBGRAPHWITHLISTSIZE + @join__graph(name: "subgraphWithListSize", url: "http://localhost:4002") +} + +scalar link__Import + +enum link__Purpose { + """ + `SECURITY` features provide metadata necessary to securely resolve fields. + """ + SECURITY + + """ + `EXECUTION` features provide metadata necessary for operation execution. + """ + EXECUTION +} + +type Query + @join__type(graph: SUBGRAPHWITHCOST) + @join__type(graph: SUBGRAPHWITHLISTSIZE) { + fieldWithCost: Int + @join__field(graph: SUBGRAPHWITHCOST) + @renamedCost(weight: 5) + argWithCost(arg: Int @renamedCost(weight: 10)): Int + @join__field(graph: SUBGRAPHWITHCOST) + enumWithCost: AorB @join__field(graph: SUBGRAPHWITHCOST) + inputWithCost(someInput: InputTypeWithCost): Int + @join__field(graph: SUBGRAPHWITHCOST) + scalarWithCost: ExpensiveInt @join__field(graph: SUBGRAPHWITHCOST) + objectWithCost: ExpensiveObject @join__field(graph: SUBGRAPHWITHCOST) + fieldWithListSize: [String!] + @join__field(graph: SUBGRAPHWITHLISTSIZE) + @renamedListSize(assumedSize: 2000, requireOneSlicingArgument: false) + fieldWithDynamicListSize(first: Int = 10): SizedField + @join__field(graph: SUBGRAPHWITHLISTSIZE) + @renamedListSize( + slicingArguments: ["first"] + sizedFields: ["items"] + requireOneSlicingArgument: true + ) +} + +type SizedField @join__type(graph: SUBGRAPHWITHLISTSIZE) { + items: [A] +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/mod.rs b/apollo-router/src/plugins/demand_control/cost_calculator/mod.rs index 290ce4dbe4..a534f91a94 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/mod.rs +++ b/apollo-router/src/plugins/demand_control/cost_calculator/mod.rs @@ -1,4 +1,5 @@ mod directives; +pub(in crate::plugins::demand_control) mod schema; pub(crate) mod static_cost; use crate::plugins::demand_control::DemandControlError; diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/schema.rs b/apollo-router/src/plugins/demand_control/cost_calculator/schema.rs new file mode 100644 index 0000000000..6a46ee9fe9 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/cost_calculator/schema.rs @@ -0,0 +1,180 @@ +use std::ops::Deref; +use std::sync::Arc; + +use ahash::HashMap; +use ahash::HashMapExt; +use apollo_compiler::schema::ExtendedType; +use apollo_compiler::validation::Valid; +use apollo_compiler::Name; +use apollo_compiler::Schema; + +use super::directives::get_apollo_directive_names; +use super::directives::CostDirective; +use super::directives::DefinitionListSizeDirective as ListSizeDirective; +use super::directives::RequiresDirective; +use crate::plugins::demand_control::DemandControlError; + +pub(crate) struct DemandControlledSchema { + directive_name_map: HashMap, + inner: Arc>, + type_field_cost_directives: HashMap>, + type_field_list_size_directives: HashMap>, + type_field_requires_directives: HashMap>, +} + +impl DemandControlledSchema { + pub(crate) fn new(schema: Arc>) -> Result { + let directive_name_map = get_apollo_directive_names(&schema); + + let mut type_field_cost_directives: HashMap> = + HashMap::new(); + let mut type_field_list_size_directives: HashMap> = + HashMap::new(); + let mut type_field_requires_directives: HashMap> = + HashMap::new(); + + for (type_name, type_) in &schema.types { + let field_cost_directives = type_field_cost_directives + .entry(type_name.clone()) + .or_default(); + let field_list_size_directives = type_field_list_size_directives + .entry(type_name.clone()) + .or_default(); + let field_requires_directives = type_field_requires_directives + .entry(type_name.clone()) + .or_default(); + + match type_ { + ExtendedType::Interface(ty) => { + for field_name in ty.fields.keys() { + let field_definition = schema.type_field(type_name, field_name)?; + let field_type = schema.types.get(field_definition.ty.inner_named_type()).ok_or_else(|| { + DemandControlError::QueryParseFailure(format!( + "Field {} was found in query, but its type is missing from the schema.", + field_name + )) + })?; + + if let Some(cost_directive) = + CostDirective::from_field(&directive_name_map, field_definition) + .or(CostDirective::from_type(&directive_name_map, field_type)) + { + field_cost_directives.insert(field_name.clone(), cost_directive); + } + + if let Some(list_size_directive) = ListSizeDirective::from_field_definition( + &directive_name_map, + field_definition, + )? { + field_list_size_directives + .insert(field_name.clone(), list_size_directive); + } + + if let Some(requires_directive) = RequiresDirective::from_field_definition( + field_definition, + type_name, + &schema, + )? { + field_requires_directives + .insert(field_name.clone(), requires_directive); + } + } + } + ExtendedType::Object(ty) => { + for field_name in ty.fields.keys() { + let field_definition = schema.type_field(type_name, field_name)?; + let field_type = schema.types.get(field_definition.ty.inner_named_type()).ok_or_else(|| { + DemandControlError::QueryParseFailure(format!( + "Field {} was found in query, but its type is missing from the schema.", + field_name + )) + })?; + + if let Some(cost_directive) = + CostDirective::from_field(&directive_name_map, field_definition) + .or(CostDirective::from_type(&directive_name_map, field_type)) + { + field_cost_directives.insert(field_name.clone(), cost_directive); + } + + if let Some(list_size_directive) = ListSizeDirective::from_field_definition( + &directive_name_map, + field_definition, + )? { + field_list_size_directives + .insert(field_name.clone(), list_size_directive); + } + + if let Some(requires_directive) = RequiresDirective::from_field_definition( + field_definition, + type_name, + &schema, + )? { + field_requires_directives + .insert(field_name.clone(), requires_directive); + } + } + } + _ => { + // Other types don't have fields + } + } + } + + Ok(Self { + directive_name_map, + inner: schema, + type_field_cost_directives, + type_field_list_size_directives, + type_field_requires_directives, + }) + } + + pub(in crate::plugins::demand_control) fn directive_name_map(&self) -> &HashMap { + &self.directive_name_map + } + + pub(in crate::plugins::demand_control) fn type_field_cost_directive( + &self, + type_name: &str, + field_name: &str, + ) -> Option<&CostDirective> { + self.type_field_cost_directives + .get(type_name)? + .get(field_name) + } + + pub(in crate::plugins::demand_control) fn type_field_list_size_directive( + &self, + type_name: &str, + field_name: &str, + ) -> Option<&ListSizeDirective> { + self.type_field_list_size_directives + .get(type_name)? + .get(field_name) + } + + pub(in crate::plugins::demand_control) fn type_field_requires_directive( + &self, + type_name: &str, + field_name: &str, + ) -> Option<&RequiresDirective> { + self.type_field_requires_directives + .get(type_name)? + .get(field_name) + } +} + +impl AsRef> for DemandControlledSchema { + fn as_ref(&self) -> &Valid { + &self.inner + } +} + +impl Deref for DemandControlledSchema { + type Target = Schema; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs index 7601ba71e5..4f2e585db3 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs +++ b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use ahash::HashMap; +use apollo_compiler::ast; use apollo_compiler::ast::InputValueDefinition; use apollo_compiler::ast::NamedType; use apollo_compiler::executable::ExecutableDocument; @@ -9,18 +11,19 @@ use apollo_compiler::executable::InlineFragment; use apollo_compiler::executable::Operation; use apollo_compiler::executable::Selection; use apollo_compiler::executable::SelectionSet; -use apollo_compiler::validation::Valid; -use apollo_compiler::Schema; +use apollo_compiler::schema::ExtendedType; +use apollo_compiler::Node; use serde_json_bytes::Value; use super::directives::IncludeDirective; -use super::directives::RequiresDirective; use super::directives::SkipDirective; +use super::schema::DemandControlledSchema; use super::DemandControlError; use crate::graphql::Response; use crate::graphql::ResponseVisitor; +use crate::plugins::demand_control::cost_calculator::directives::CostDirective; +use crate::plugins::demand_control::cost_calculator::directives::ListSizeDirective; use crate::query_planner::fetch::SubgraphOperation; -use crate::query_planner::fetch::SubgraphSchemas; use crate::query_planner::DeferredNode; use crate::query_planner::PlanNode; use crate::query_planner::Primary; @@ -28,13 +31,74 @@ use crate::query_planner::QueryPlan; pub(crate) struct StaticCostCalculator { list_size: u32, - subgraph_schemas: Arc, + supergraph_schema: Arc, + subgraph_schemas: Arc>, +} + +fn score_argument( + argument: &apollo_compiler::ast::Value, + argument_definition: &Node, + schema: &DemandControlledSchema, +) -> Result { + let cost_directive = + CostDirective::from_argument(schema.directive_name_map(), argument_definition); + let ty = schema + .types + .get(argument_definition.ty.inner_named_type()) + .ok_or_else(|| { + DemandControlError::QueryParseFailure(format!( + "Argument {} was found in query, but its type ({}) was not found in the schema", + argument_definition.name, + argument_definition.ty.inner_named_type() + )) + })?; + + match (argument, ty) { + (_, ExtendedType::Interface(_)) + | (_, ExtendedType::Object(_)) + | (_, ExtendedType::Union(_)) => Err(DemandControlError::QueryParseFailure( + format!( + "Argument {} has type {}, but objects, interfaces, and unions are disallowed in this position", + argument_definition.name, + argument_definition.ty.inner_named_type() + ) + )), + + (ast::Value::Object(inner_args), ExtendedType::InputObject(inner_arg_defs)) => { + let mut cost = cost_directive.map_or(1.0, |cost| cost.weight()); + for (arg_name, arg_val) in inner_args { + let arg_def = inner_arg_defs.fields.get(arg_name).ok_or_else(|| { + DemandControlError::QueryParseFailure(format!( + "Argument {} was found in query, but its type ({}) was not found in the schema", + argument_definition.name, + argument_definition.ty.inner_named_type() + )) + })?; + cost += score_argument(arg_val, arg_def, schema)?; + } + Ok(cost) + } + (ast::Value::List(inner_args), _) => { + let mut cost = cost_directive.map_or(0.0, |cost| cost.weight()); + for arg_val in inner_args { + cost += score_argument(arg_val, argument_definition, schema)?; + } + Ok(cost) + } + (ast::Value::Null, _) => Ok(0.0), + _ => Ok(cost_directive.map_or(0.0, |cost| cost.weight())) + } } impl StaticCostCalculator { - pub(crate) fn new(subgraph_schemas: Arc, list_size: u32) -> Self { + pub(crate) fn new( + supergraph_schema: Arc, + subgraph_schemas: Arc>, + list_size: u32, + ) -> Self { Self { list_size, + supergraph_schema, subgraph_schemas, } } @@ -61,14 +125,18 @@ impl StaticCostCalculator { &self, field: &Field, parent_type: &NamedType, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, + list_size_from_upstream: Option, ) -> Result { if StaticCostCalculator::skipped_by_directives(field) { return Ok(0.0); } + // We need to look up the `FieldDefinition` from the supergraph schema instead of using `field.definition` + // because `field.definition` was generated from the API schema, which strips off the directives we need. + let definition = schema.type_field(parent_type, &field.name)?; let ty = field.inner_type_def(schema).ok_or_else(|| { DemandControlError::QueryParseFailure(format!( "Field {} was found in query, but its type is missing from the schema.", @@ -76,17 +144,32 @@ impl StaticCostCalculator { )) })?; - // Determine how many instances we're scoring. If there's no user-provided - // information, assume lists have 100 items. - let instance_count = if field.ty().is_list() { - self.list_size as f64 + let list_size_directive = + match schema.type_field_list_size_directive(parent_type, &field.name) { + Some(dir) => dir.with_field(field).map(Some), + None => Ok(None), + }?; + let instance_count = if !field.ty().is_list() { + 1 + } else if let Some(value) = list_size_from_upstream { + // This is a sized field whose length is defined by the `@listSize` directive on the parent field + value + } else if let Some(expected_size) = list_size_directive + .as_ref() + .and_then(|dir| dir.expected_size) + { + expected_size } else { - 1.0 + self.list_size as i32 }; // Determine the cost for this particular field. Scalars are free, non-scalars are not. // For fields with selections, add in the cost of the selections as well. - let mut type_cost = if ty.is_interface() || ty.is_object() || ty.is_union() { + let mut type_cost = if let Some(cost_directive) = + schema.type_field_cost_directive(parent_type, &field.name) + { + cost_directive.weight() + } else if ty.is_interface() || ty.is_object() || ty.is_union() { 1.0 } else { 0.0 @@ -97,10 +180,19 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + list_size_directive.as_ref(), )?; - for argument in &field.definition.arguments { - type_cost += Self::score_argument(argument, schema)?; + let mut arguments_cost = 0.0; + for argument in &field.arguments { + let argument_definition = + definition.argument_by_name(&argument.name).ok_or_else(|| { + DemandControlError::QueryParseFailure(format!( + "Argument {} of field {} is missing a definition in the schema", + argument.name, field.name + )) + })?; + arguments_cost += score_argument(&argument.value, argument_definition, schema)?; } let mut requirements_cost = 0.0; @@ -108,25 +200,28 @@ impl StaticCostCalculator { // If the field is marked with `@requires`, the required selection may not be included // in the query's selection. Adding that requirement's cost to the field ensures it's // accounted for. - let requirements = - RequiresDirective::from_field(field, parent_type, schema)?.map(|d| d.fields); + let requirements = schema + .type_field_requires_directive(parent_type, &field.name) + .map(|d| &d.fields); if let Some(selection_set) = requirements { requirements_cost = self.score_selection_set( - &selection_set, + selection_set, parent_type, schema, executable, should_estimate_requires, + list_size_directive.as_ref(), )?; } } - let cost = instance_count * type_cost + requirements_cost; + let cost = (instance_count as f64) * type_cost + arguments_cost + requirements_cost; tracing::debug!( - "Field {} cost breakdown: (count) {} * (type cost) {} + (requirements) {} = {}", + "Field {} cost breakdown: (count) {} * (type cost) {} + (arguments) {} + (requirements) {} = {}", field.name, instance_count, type_cost, + arguments_cost, requirements_cost, cost ); @@ -134,47 +229,14 @@ impl StaticCostCalculator { Ok(cost) } - fn score_argument( - argument: &InputValueDefinition, - schema: &Valid, - ) -> Result { - if let Some(ty) = schema.types.get(argument.ty.inner_named_type().as_str()) { - match ty { - apollo_compiler::schema::ExtendedType::InputObject(inner_arguments) => { - let mut cost = 1.0; - for inner_argument in inner_arguments.fields.values() { - cost += Self::score_argument(inner_argument, schema)?; - } - Ok(cost) - } - - apollo_compiler::schema::ExtendedType::Scalar(_) - | apollo_compiler::schema::ExtendedType::Enum(_) => Ok(0.0), - - apollo_compiler::schema::ExtendedType::Object(_) - | apollo_compiler::schema::ExtendedType::Interface(_) - | apollo_compiler::schema::ExtendedType::Union(_) => { - Err(DemandControlError::QueryParseFailure( - format!("Argument {} has type {}, but objects, interfaces, and unions are disallowed in this position", argument.name, argument.ty.inner_named_type()) - )) - } - } - } else { - Err(DemandControlError::QueryParseFailure(format!( - "Argument {} was found in query, but its type ({}) was not found in the schema", - argument.name, - argument.ty.inner_named_type() - ))) - } - } - fn score_fragment_spread( &self, fragment_spread: &FragmentSpread, parent_type: &NamedType, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, + list_size_directive: Option<&ListSizeDirective>, ) -> Result { let fragment = fragment_spread.fragment_def(executable).ok_or_else(|| { DemandControlError::QueryParseFailure(format!( @@ -188,6 +250,7 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + list_size_directive, ) } @@ -195,9 +258,10 @@ impl StaticCostCalculator { &self, inline_fragment: &InlineFragment, parent_type: &NamedType, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, + list_size_directive: Option<&ListSizeDirective>, ) -> Result { self.score_selection_set( &inline_fragment.selection_set, @@ -205,13 +269,14 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + list_size_directive, ) } fn score_operation( &self, operation: &Operation, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, ) -> Result { @@ -230,6 +295,7 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + None, )?; Ok(cost) @@ -239,20 +305,27 @@ impl StaticCostCalculator { &self, selection: &Selection, parent_type: &NamedType, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, + list_size_directive: Option<&ListSizeDirective>, ) -> Result { match selection { - Selection::Field(f) => { - self.score_field(f, parent_type, schema, executable, should_estimate_requires) - } + Selection::Field(f) => self.score_field( + f, + parent_type, + schema, + executable, + should_estimate_requires, + list_size_directive.and_then(|dir| dir.size_of(f)), + ), Selection::FragmentSpread(s) => self.score_fragment_spread( s, parent_type, schema, executable, should_estimate_requires, + list_size_directive, ), Selection::InlineFragment(i) => self.score_inline_fragment( i, @@ -260,6 +333,7 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + list_size_directive, ), } } @@ -268,9 +342,10 @@ impl StaticCostCalculator { &self, selection_set: &SelectionSet, parent_type_name: &NamedType, - schema: &Valid, + schema: &DemandControlledSchema, executable: &ExecutableDocument, should_estimate_requires: bool, + list_size_directive: Option<&ListSizeDirective>, ) -> Result { let mut cost = 0.0; for selection in selection_set.selections.iter() { @@ -280,6 +355,7 @@ impl StaticCostCalculator { schema, executable, should_estimate_requires, + list_size_directive, )?; } Ok(cost) @@ -386,7 +462,7 @@ impl StaticCostCalculator { pub(crate) fn estimated( &self, query: &ExecutableDocument, - schema: &Valid, + schema: &DemandControlledSchema, should_estimate_requires: bool, ) -> Result { let mut cost = 0.0; @@ -408,39 +484,75 @@ impl StaticCostCalculator { request: &ExecutableDocument, response: &Response, ) -> Result { - let mut visitor = ResponseCostCalculator::new(); + let mut visitor = ResponseCostCalculator::new(&self.supergraph_schema); visitor.visit(request, response); Ok(visitor.cost) } } -pub(crate) struct ResponseCostCalculator { +pub(crate) struct ResponseCostCalculator<'a> { pub(crate) cost: f64, + schema: &'a DemandControlledSchema, } -impl ResponseCostCalculator { - pub(crate) fn new() -> Self { - Self { cost: 0.0 } +impl<'schema> ResponseCostCalculator<'schema> { + pub(crate) fn new(schema: &'schema DemandControlledSchema) -> Self { + Self { cost: 0.0, schema } } } -impl ResponseVisitor for ResponseCostCalculator { +impl<'schema> ResponseVisitor for ResponseCostCalculator<'schema> { fn visit_field( &mut self, request: &ExecutableDocument, - _ty: &NamedType, + parent_ty: &NamedType, field: &Field, value: &Value, ) { + self.visit_list_item(request, parent_ty, field, value); + + let definition = self.schema.type_field(parent_ty, &field.name); + for argument in &field.arguments { + if let Ok(Some(argument_definition)) = definition + .as_ref() + .map(|def| def.argument_by_name(&argument.name)) + { + if let Ok(score) = score_argument(&argument.value, argument_definition, self.schema) + { + self.cost += score; + } + } else { + tracing::warn!( + "Failed to get schema definition for argument {} of field {}. The resulting actual cost will be a partial result.", + argument.name, + field.name + ) + } + } + } + + fn visit_list_item( + &mut self, + request: &apollo_compiler::ExecutableDocument, + parent_ty: &apollo_compiler::executable::NamedType, + field: &apollo_compiler::executable::Field, + value: &Value, + ) { + let cost_directive = self + .schema + .type_field_cost_directive(parent_ty, &field.name); + match value { - Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {} + Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => { + self.cost += cost_directive.map_or(0.0, |cost| cost.weight()); + } Value::Array(items) => { for item in items { - self.visit_field(request, field.ty().inner_named_type(), field, item); + self.visit_list_item(request, parent_ty, field, item); } } Value::Object(children) => { - self.cost += 1.0; + self.cost += cost_directive.map_or(1.0, |cost| cost.weight()); self.visit_selections(request, &field.selection_set, children); } } @@ -451,19 +563,26 @@ impl ResponseVisitor for ResponseCostCalculator { mod tests { use std::sync::Arc; + use ahash::HashMapExt; + use apollo_federation::query_plan::query_planner::QueryPlanner; use bytes::Bytes; use test_log::test; - use tower::Service; use super::*; - use crate::query_planner::BridgeQueryPlanner; use crate::services::layers::query_analysis::ParsedDocument; - use crate::services::QueryPlannerContent; - use crate::services::QueryPlannerRequest; use crate::spec; use crate::spec::Query; use crate::Configuration; - use crate::Context; + + impl StaticCostCalculator { + fn rust_planned( + &self, + query_plan: &apollo_federation::query_plan::QueryPlan, + ) -> Result { + let js_planner_node: PlanNode = query_plan.node.as_ref().unwrap().into(); + self.score_plan_node(&js_planner_node) + } + } fn parse_schema_and_operation( schema_str: &str, @@ -479,8 +598,12 @@ mod tests { fn estimated_cost(schema_str: &str, query_str: &str) -> f64 { let (schema, query) = parse_schema_and_operation(schema_str, query_str, &Default::default()); - StaticCostCalculator::new(Default::default(), 100) - .estimated(&query.executable, schema.supergraph_schema(), true) + let schema = + DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(); + let calculator = StaticCostCalculator::new(Arc::new(schema), Default::default(), 100); + + calculator + .estimated(&query.executable, &calculator.supergraph_schema, true) .unwrap() } @@ -494,8 +617,11 @@ mod tests { "query.graphql", ) .unwrap(); - StaticCostCalculator::new(Default::default(), 100) - .estimated(&query, &schema, true) + let schema = DemandControlledSchema::new(Arc::new(schema)).unwrap(); + let calculator = StaticCostCalculator::new(Arc::new(schema), Default::default(), 100); + + calculator + .estimated(&query, &calculator.supergraph_schema, true) .unwrap() } @@ -503,40 +629,59 @@ mod tests { let config: Arc = Arc::new(Default::default()); let (schema, query) = parse_schema_and_operation(schema_str, query_str, &config); - let mut planner = BridgeQueryPlanner::new(schema.into(), config.clone(), None, None) - .await - .unwrap(); + let planner = + QueryPlanner::new(schema.federation_supergraph(), Default::default()).unwrap(); - let ctx = Context::new(); - ctx.extensions() - .with_lock(|mut lock| lock.insert::(query)); + let query_plan = planner.build_query_plan(&query.executable, None).unwrap(); - let planner_res = planner - .call(QueryPlannerRequest::new(query_str.to_string(), None, ctx)) - .await - .unwrap(); - let query_plan = match planner_res.content.unwrap() { - QueryPlannerContent::Plan { plan } => plan, - _ => panic!("Query planner returned unexpected non-plan content"), - }; + let schema = + DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(); + let mut demand_controlled_subgraph_schemas = HashMap::new(); + for (subgraph_name, subgraph_schema) in planner.subgraph_schemas().iter() { + let demand_controlled_subgraph_schema = + DemandControlledSchema::new(Arc::new(subgraph_schema.schema().clone())).unwrap(); + demand_controlled_subgraph_schemas + .insert(subgraph_name.to_string(), demand_controlled_subgraph_schema); + } - let calculator = StaticCostCalculator { - subgraph_schemas: planner.subgraph_schemas(), - list_size: 100, - }; + let calculator = StaticCostCalculator::new( + Arc::new(schema), + Arc::new(demand_controlled_subgraph_schemas), + 100, + ); - calculator.planned(&query_plan).unwrap() + calculator.rust_planned(&query_plan).unwrap() } fn actual_cost(schema_str: &str, query_str: &str, response_bytes: &'static [u8]) -> f64 { - let (_schema, query) = + let (schema, query) = parse_schema_and_operation(schema_str, query_str, &Default::default()); let response = Response::from_bytes("test", Bytes::from(response_bytes)).unwrap(); - StaticCostCalculator::new(Default::default(), 100) + let schema = + DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(); + StaticCostCalculator::new(Arc::new(schema), Default::default(), 100) .actual(&query.executable, &response) .unwrap() } + /// Actual cost of an operation on a plain, non-federated schema. + fn basic_actual_cost(schema_str: &str, query_str: &str, response_bytes: &'static [u8]) -> f64 { + let schema = + apollo_compiler::Schema::parse_and_validate(schema_str, "schema.graphqls").unwrap(); + let query = apollo_compiler::ExecutableDocument::parse_and_validate( + &schema, + query_str, + "query.graphql", + ) + .unwrap(); + let response = Response::from_bytes("test", Bytes::from(response_bytes)).unwrap(); + + let schema = DemandControlledSchema::new(Arc::new(schema)).unwrap(); + StaticCostCalculator::new(Arc::new(schema), Default::default(), 100) + .actual(&query, &response) + .unwrap() + } + #[test] fn query_cost() { let schema = include_str!("./fixtures/basic_schema.graphql"); @@ -606,7 +751,18 @@ mod tests { let schema = include_str!("./fixtures/basic_schema.graphql"); let query = include_str!("./fixtures/basic_input_object_query.graphql"); - assert_eq!(basic_estimated_cost(schema, query), 2.0) + assert_eq!(basic_estimated_cost(schema, query), 4.0) + } + + #[test] + fn input_object_cost_with_returned_objects() { + let schema = include_str!("./fixtures/basic_schema.graphql"); + let query = include_str!("./fixtures/basic_input_object_query_2.graphql"); + let response = include_bytes!("./fixtures/basic_input_object_response.json"); + + assert_eq!(basic_estimated_cost(schema, query), 104.0); + // The cost of the arguments from the query should be included when scoring the response + assert_eq!(basic_actual_cost(schema, query, response), 7.0); } #[test] @@ -684,15 +840,55 @@ mod tests { let schema = include_str!("./fixtures/federated_ships_schema.graphql"); let query = include_str!("./fixtures/federated_ships_deferred_query.graphql"); let (schema, query) = parse_schema_and_operation(schema, query, &Default::default()); + let schema = Arc::new( + DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(), + ); - let conservative_estimate = StaticCostCalculator::new(Default::default(), 100) - .estimated(&query.executable, schema.supergraph_schema(), true) + let calculator = StaticCostCalculator::new(schema.clone(), Default::default(), 100); + let conservative_estimate = calculator + .estimated(&query.executable, &calculator.supergraph_schema, true) .unwrap(); - let narrow_estimate = StaticCostCalculator::new(Default::default(), 5) - .estimated(&query.executable, schema.supergraph_schema(), true) + + let calculator = StaticCostCalculator::new(schema.clone(), Default::default(), 5); + let narrow_estimate = calculator + .estimated(&query.executable, &calculator.supergraph_schema, true) .unwrap(); assert_eq!(conservative_estimate, 10200.0); assert_eq!(narrow_estimate, 35.0); } + + #[test(tokio::test)] + async fn custom_cost_query() { + let schema = include_str!("./fixtures/custom_cost_schema.graphql"); + let query = include_str!("./fixtures/custom_cost_query.graphql"); + let response = include_bytes!("./fixtures/custom_cost_response.json"); + + assert_eq!(estimated_cost(schema, query), 127.0); + assert_eq!(planned_cost(schema, query).await, 127.0); + assert_eq!(actual_cost(schema, query, response), 125.0); + } + + #[test(tokio::test)] + async fn custom_cost_query_with_renamed_directives() { + let schema = include_str!("./fixtures/custom_cost_schema_with_renamed_directives.graphql"); + let query = include_str!("./fixtures/custom_cost_query.graphql"); + let response = include_bytes!("./fixtures/custom_cost_response.json"); + + assert_eq!(estimated_cost(schema, query), 127.0); + assert_eq!(planned_cost(schema, query).await, 127.0); + assert_eq!(actual_cost(schema, query, response), 125.0); + } + + #[test(tokio::test)] + async fn custom_cost_query_with_default_slicing_argument() { + let schema = include_str!("./fixtures/custom_cost_schema.graphql"); + let query = + include_str!("./fixtures/custom_cost_query_with_default_slicing_argument.graphql"); + let response = include_bytes!("./fixtures/custom_cost_response.json"); + + assert_eq!(estimated_cost(schema, query), 132.0); + assert_eq!(planned_cost(schema, query).await, 132.0); + assert_eq!(actual_cost(schema, query, response), 125.0); + } } diff --git a/apollo-router/src/plugins/demand_control/mod.rs b/apollo-router/src/plugins/demand_control/mod.rs index 476deeb737..bf0cdf5f26 100644 --- a/apollo-router/src/plugins/demand_control/mod.rs +++ b/apollo-router/src/plugins/demand_control/mod.rs @@ -5,6 +5,9 @@ use std::future; use std::ops::ControlFlow; use std::sync::Arc; +use ahash::HashMap; +use ahash::HashMapExt; +use apollo_compiler::schema::FieldLookupError; use apollo_compiler::validation::Valid; use apollo_compiler::validation::WithErrors; use apollo_compiler::ExecutableDocument; @@ -27,6 +30,7 @@ use crate::json_ext::Object; use crate::layers::ServiceBuilderExt; use crate::plugin::Plugin; use crate::plugin::PluginInit; +use crate::plugins::demand_control::cost_calculator::schema::DemandControlledSchema; use crate::plugins::demand_control::strategy::Strategy; use crate::plugins::demand_control::strategy::StrategyFactory; use crate::register_plugin; @@ -199,6 +203,22 @@ impl From> for DemandControlError { } } +impl<'a> From> for DemandControlError { + fn from(value: FieldLookupError) -> Self { + match value { + FieldLookupError::NoSuchType => DemandControlError::QueryParseFailure( + "Attempted to look up a type which does not exist in the schema".to_string(), + ), + FieldLookupError::NoSuchField(type_name, _) => { + DemandControlError::QueryParseFailure(format!( + "Attempted to look up a field on type {}, but the field does not exist", + type_name + )) + } + } + } +} + pub(crate) struct DemandControl { config: DemandControlConfig, strategy_factory: StrategyFactory, @@ -223,11 +243,21 @@ impl Plugin for DemandControl { type Config = DemandControlConfig; async fn new(init: PluginInit) -> Result { + let demand_controlled_supergraph_schema = + DemandControlledSchema::new(init.supergraph_schema.clone())?; + let mut demand_controlled_subgraph_schemas = HashMap::new(); + for (subgraph_name, subgraph_schema) in init.subgraph_schemas.iter() { + let demand_controlled_subgraph_schema = + DemandControlledSchema::new(subgraph_schema.clone())?; + demand_controlled_subgraph_schemas + .insert(subgraph_name.clone(), demand_controlled_subgraph_schema); + } + Ok(DemandControl { strategy_factory: StrategyFactory::new( init.config.clone(), - init.supergraph_schema.clone(), - init.subgraph_schemas.clone(), + Arc::new(demand_controlled_supergraph_schema), + Arc::new(demand_controlled_subgraph_schemas), ), config: init.config, }) diff --git a/apollo-router/src/plugins/demand_control/strategy/mod.rs b/apollo-router/src/plugins/demand_control/strategy/mod.rs index 5defca64d5..6bae126694 100644 --- a/apollo-router/src/plugins/demand_control/strategy/mod.rs +++ b/apollo-router/src/plugins/demand_control/strategy/mod.rs @@ -1,11 +1,10 @@ -use std::collections::HashMap; use std::sync::Arc; -use apollo_compiler::validation::Valid; +use ahash::HashMap; use apollo_compiler::ExecutableDocument; -use apollo_compiler::Schema; use crate::graphql; +use crate::plugins::demand_control::cost_calculator::schema::DemandControlledSchema; use crate::plugins::demand_control::cost_calculator::static_cost::StaticCostCalculator; use crate::plugins::demand_control::strategy::static_estimated::StaticEstimated; use crate::plugins::demand_control::DemandControlConfig; @@ -75,15 +74,15 @@ impl Strategy { pub(crate) struct StrategyFactory { config: DemandControlConfig, #[allow(dead_code)] - supergraph_schema: Arc>, - subgraph_schemas: Arc>>>, + supergraph_schema: Arc, + subgraph_schemas: Arc>, } impl StrategyFactory { pub(crate) fn new( config: DemandControlConfig, - supergraph_schema: Arc>, - subgraph_schemas: Arc>>>, + supergraph_schema: Arc, + subgraph_schemas: Arc>, ) -> Self { Self { config, @@ -97,6 +96,7 @@ impl StrategyFactory { StrategyConfig::StaticEstimated { list_size, max } => Arc::new(StaticEstimated { max: *max, cost_calculator: StaticCostCalculator::new( + self.supergraph_schema.clone(), self.subgraph_schemas.clone(), *list_size, ),