diff --git a/src/mkdocs_bibtex/config.py b/src/mkdocs_bibtex/config.py index 8d441fc..c491a6e 100644 --- a/src/mkdocs_bibtex/config.py +++ b/src/mkdocs_bibtex/config.py @@ -29,4 +29,4 @@ class BibTexConfig(base.Config): # Settings bib_by_default = config_options.Type(bool, default=True) - footnote_format = config_options.Type(str, default="{number}") + footnote_format = config_options.Type(str, default="{key}") diff --git a/src/mkdocs_bibtex/plugin.py b/src/mkdocs_bibtex/plugin.py index 7a0e12e..09d3898 100644 --- a/src/mkdocs_bibtex/plugin.py +++ b/src/mkdocs_bibtex/plugin.py @@ -68,13 +68,15 @@ def on_config(self, config): else: self.csl_file = self.config.csl_file - if "{number}" not in self.config.footnote_format: - raise ConfigurationError("Must include `{number}` placeholder in footnote_format") + if "{key}" not in self.config.footnote_format: + raise ConfigurationError("Must include `{key}` placeholder in footnote_format") if self.csl_file: - self.registry = PandocRegistry(bib_files=bibfiles, csl_file=self.csl_file) + self.registry = PandocRegistry( + bib_files=bibfiles, csl_file=self.csl_file, footnote_format=self.config.footnote_format + ) else: - self.registry = SimpleRegistry(bib_files=bibfiles) + self.registry = SimpleRegistry(bib_files=bibfiles, footnote_format=self.config.footnote_format) self.last_configured = time.time() return config @@ -121,7 +123,11 @@ def on_page_markdown(self, markdown, page, config, files): bibliography = [] for citation in citations.values(): try: - bibliography.append("[^{}]: {}".format(citation.key, self.registry.reference_text(citation))) + bibliography.append( + "[^{}]: {}".format( + self.registry.footnote_format.format(key=citation.key), self.registry.reference_text(citation) + ) + ) except Exception as e: log.warning(f"Error formatting citation {citation.key}: {e}") bibliography = "\n".join(bibliography) @@ -136,7 +142,11 @@ def on_page_markdown(self, markdown, page, config, files): self.registry.validate_citation_blocks(blocks) full_bibliography = [] for citation in all_citations: - full_bibliography.append("[^{}]: {}".format(citation.key, self.registry.reference_text(citation))) + full_bibliography.append( + "[^{}]: {}".format( + self.registry.footnote_format.format(key=citation.key), self.registry.reference_text(citation) + ) + ) full_bibliography = "\n".join(full_bibliography) markdown = markdown.replace(full_bib_command, full_bibliography) diff --git a/src/mkdocs_bibtex/registry.py b/src/mkdocs_bibtex/registry.py index 5f82b19..915fd74 100644 --- a/src/mkdocs_bibtex/registry.py +++ b/src/mkdocs_bibtex/registry.py @@ -15,7 +15,7 @@ class ReferenceRegistry(ABC): A registry of references that can be used to format citations """ - def __init__(self, bib_files: list[str]): + def __init__(self, bib_files: list[str], footnote_format: str = "{key}"): refs = {} log.info(f"Loading data from bib files: {bib_files}") for bibfile in bib_files: @@ -23,6 +23,7 @@ def __init__(self, bib_files: list[str]): bibdata = parse_file(bibfile) refs.update(bibdata.entries) self.bib_data = BibliographyData(entries=refs) + self.footnote_format = footnote_format @abstractmethod def validate_citation_blocks(self, citation_blocks: list[CitationBlock]) -> None: @@ -38,8 +39,8 @@ def reference_text(self, citation: Citation) -> str: class SimpleRegistry(ReferenceRegistry): - def __init__(self, bib_files: list[str]): - super().__init__(bib_files) + def __init__(self, bib_files: list[str], footnote_format: str = "{key}"): + super().__init__(bib_files, footnote_format) self.style = PlainStyle() self.backend = MarkdownBackend() @@ -56,7 +57,11 @@ def validate_citation_blocks(self, citation_blocks: list[CitationBlock]) -> None log.warning(f"Affixes not supported in simple mode: {citation}") def inline_text(self, citation_block: CitationBlock) -> str: - keys = [citation.key for citation in citation_block.citations if citation.key in self.bib_data.entries] + keys = [ + self.footnote_format.format(key=citation.key) + for citation in citation_block.citations + if citation.key in self.bib_data.entries + ] return "".join(f"[^{key}]" for key in keys) def reference_text(self, citation: Citation) -> str: @@ -74,8 +79,8 @@ def reference_text(self, citation: Citation) -> str: class PandocRegistry(ReferenceRegistry): """A registry that uses Pandoc to format citations""" - def __init__(self, bib_files: list[str], csl_file: str): - super().__init__(bib_files) + def __init__(self, bib_files: list[str], csl_file: str, footnote_format: str = "{key}"): + super().__init__(bib_files, footnote_format) self.csl_file = csl_file # Get pandoc version for formatting decisions @@ -91,7 +96,9 @@ def __init__(self, bib_files: list[str], csl_file: str): def inline_text(self, citation_block: CitationBlock) -> str: """Get the inline text for a citation block""" footnotes = " ".join( - f"[^{citation.key}]" for citation in citation_block.citations if citation.key in self._reference_cache + f"[^{self.footnote_format.format(key=citation.key)}]" + for citation in citation_block.citations + if citation.key in self._reference_cache ) if self._is_inline: diff --git a/test_files/test_integration.py b/test_files/test_integration.py index aae508c..e46cb60 100644 --- a/test_files/test_integration.py +++ b/test_files/test_integration.py @@ -120,14 +120,22 @@ def test_bibliography_controls(plugin): assert "[^test]:" in result -@pytest.mark.xfail(reason="Need to reimplement footnote formatting") -def test_custom_footnote_format(plugin): +def test_custom_footnote_format(): """Test custom footnote formatting""" - plugin.config.footnote_format = "ref{number}" + plugin = BibTexPlugin() + plugin.load_config( + options={ + "bib_file": os.path.join(test_files_dir, "test.bib"), + "bib_by_default": False, + "footnote_format": "ref-{key}", + }, + config_file_path=test_files_dir, + ) + plugin.on_config(plugin.config) + markdown = "Citation [@test]\n\n\\bibliography" result = plugin.on_page_markdown(markdown, None, None, None) - assert "[^reftest]" in result - + assert "[^ref-test]" in result # Test that an invalid footnote format raises an exception bad_plugin = BibTexPlugin() bad_plugin.load_config(