diff --git a/spec/compiler/semantic/module_spec.cr b/spec/compiler/semantic/module_spec.cr index 56252bc36585..a9868172c057 100644 --- a/spec/compiler/semantic/module_spec.cr +++ b/spec/compiler/semantic/module_spec.cr @@ -477,6 +477,67 @@ describe "Semantic: module" do ", "cyclic include detected" end + it "gives error when including self, generic module" do + assert_error " + module Foo(T) + include self + end + ", "cyclic include detected" + end + + it "gives error when including instantiation of self, generic module" do + assert_error " + module Foo(T) + include Foo(Int32) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include, generic module" do + assert_error " + module Foo(T) + end + + module Bar(T) + include Foo(T) + end + + module Foo(T) + include Bar(T) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include between non-generic and generic module" do + assert_error " + module Foo + end + + module Bar(T) + include Foo + end + + module Foo + include Bar(Int32) + end + ", "cyclic include detected" + end + + it "gives error with cyclic include between non-generic and generic module (2)" do + assert_error " + module Bar(T) + end + + module Foo + include Bar(Int32) + end + + module Bar(T) + include Foo + end + ", "cyclic include detected" + end + it "finds types close to included module" do assert_type(" module Foo diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr index a527208a4bbf..5922c8ae71f9 100644 --- a/src/compiler/crystal/types.cr +++ b/src/compiler/crystal/types.cr @@ -944,16 +944,22 @@ module Crystal end def include(mod) - if mod == self - raise TypeException.new "cyclic include detected" - elsif mod.ancestors.includes?(self) + generic_module = mod.is_a?(GenericModuleInstanceType) ? mod.generic_type : mod + + if generic_module == self raise TypeException.new "cyclic include detected" - else - unless parents.includes?(mod) - parents.insert 0, mod - mod.add_including_type(self) + end + + generic_module.ancestors.each do |ancestor| + if ancestor == self || ancestor.is_a?(GenericModuleInstanceType) && ancestor.generic_type == self + raise TypeException.new "cyclic include detected" end end + + unless parents.includes?(mod) + parents.insert 0, mod + mod.add_including_type(self) + end end def type_desc