diff --git a/src/NuDoq.Tests/ElementTests.cs b/src/NuDoq.Tests/ElementTests.cs new file mode 100644 index 0000000..0ee9865 --- /dev/null +++ b/src/NuDoq.Tests/ElementTests.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using NuDoq; +using Xunit; +using Xunit.Abstractions; + +namespace NuDoq +{ + public class ElementTests + { + readonly ITestOutputHelper output; + + public ElementTests(ITestOutputHelper output) => this.output = output; + + [Fact] + public void when_enumerating_elements_then_can_list_twice_enumerates_once() + { + var enumerations = 0; + + IEnumerable GetElements() + { + yield return new Text("foo"); + yield return new Text("bar"); + enumerations++; + }; + + var element = new Summary(GetElements(), new Dictionary()); + + var first = element.Elements.Count(); + var second = element.Elements.Count(); + + Assert.Contains(element.Elements.OfType(), e => e.Content == "foo"); + Assert.Contains(element.Elements.OfType(), e => e.Content == "foo"); + Assert.Contains(element.Elements.OfType(), e => e.Content == "bar"); + Assert.Contains(element.Elements.OfType(), e => e.Content == "bar"); + } + + /// + /// + /// foo + /// bar + /// baz + /// + /// + [Fact] + public void when_enumerating_elements_then_visit_twice() + { + var member = DocReader.Read(Assembly.GetExecutingAssembly()); + var method = member.Elements.OfType() + .Where(x => x.Elements.OfType().Count() > 2) + .FirstOrDefault(); + + Assert.NotNull(method); + + List? currentList = null; + var count = 0; + + var visitor = new DelegateVisitor(new VisitorDelegates + { + VisitList = (List list) => currentList = list, + VisitItem = (Item item) => { output.WriteLine(currentList.Elements.Count().ToString()); count++; }, + }); + + member.Accept(visitor); + + Assert.Equal(3, count); + } + } +} \ No newline at end of file diff --git a/src/NuDoq/CachedEnumerable.cs b/src/NuDoq/CachedEnumerable.cs index b6f3122..8881362 100644 --- a/src/NuDoq/CachedEnumerable.cs +++ b/src/NuDoq/CachedEnumerable.cs @@ -13,34 +13,12 @@ static class CachedEnumerable class CachedEnumerableImpl : IEnumerable { - IEnumerator? enumerator; readonly IEnumerable enumerable; - readonly List cache = new List(); + List? cache; public CachedEnumerableImpl(IEnumerable enumerable) => this.enumerable = enumerable; - public IEnumerator GetEnumerator() - { - // First time around, there will be nothing in - // this cache. - foreach (var item in cache) - yield return item; - - // First time we'll get the enumerator, only - // once. Next time, it will already have a value - // and so we won't enumerate twice ever. - if (enumerator == null) - enumerator = enumerable.GetEnumerator(); - - // First time around, we'll loop until we're done. - // Next time it's enumerated, this enumerator will - // return false from MoveNext right-away. - while (enumerator.MoveNext()) - { - cache.Add(enumerator.Current); - yield return enumerator.Current; - } - } + public IEnumerator GetEnumerator() => (cache ??= new List(enumerable)).GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); }