From 4b62e90c1049f2ecf2792a59ee21dc76b53de720 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Mon, 22 Mar 2021 17:16:37 +0800 Subject: [PATCH] Detect cyclic includes between generic modules --- spec/compiler/semantic/module_spec.cr | 61 +++++++++++++++++++++++++++ src/compiler/crystal/types.cr | 20 ++++++--- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/spec/compiler/semantic/module_spec.cr b/spec/compiler/semantic/module_spec.cr index 457f00fe6a4a..2853f6936775 100644 --- a/spec/compiler/semantic/module_spec.cr +++ b/spec/compiler/semantic/module_spec.cr @@ -477,6 +477,67 @@ describe "Semantic: module" do ", "cyclic include detected" end + it "gives error when including self, generic module" do + assert_error " + module Foo(T) + include self + end + ", "cyclic include detected" + end + + it "gives error when including instantiation of self, generic module" do + assert_error " + module Foo(T) + include Foo(Int32) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include, generic module" do + assert_error " + module Foo(T) + end + + module Bar(T) + include Foo(T) + end + + module Foo(T) + include Bar(T) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include between non-generic and generic module" do + assert_error " + module Foo + end + + module Bar(T) + include Foo + end + + module Foo + include Bar(Int32) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include between non-generic and generic module (2)" do + assert_error " + module Bar(T) + end + + module Foo + include Bar(Int32) + end + + module Bar(T) + include Foo + end + ", "cyclic include detected" + end + it "finds types close to included module" do assert_type(" module Foo diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr index 11494bdaf7d4..c427e937217c 100644 --- a/src/compiler/crystal/types.cr +++ b/src/compiler/crystal/types.cr @@ -959,16 +959,22 @@ module Crystal end def include(mod) - if mod == self - raise TypeException.new "cyclic include detected" - elsif mod.ancestors.includes?(self) + generic_module = mod.is_a?(GenericModuleInstanceType) ? mod.generic_type : mod + + if generic_module == self raise TypeException.new "cyclic include detected" - else - unless parents.includes?(mod) - parents.insert 0, mod - mod.add_including_type(self) + end + + generic_module.ancestors.each do |ancestor| + if ancestor == self || ancestor.is_a?(GenericModuleInstanceType) && ancestor.generic_type == self + raise TypeException.new "cyclic include detected" end end + + unless parents.includes?(mod) + parents.insert 0, mod + mod.add_including_type(self) + end end def covariant?(other_type)