Skip to content

Commit

Permalink
linting with new rules
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Jul 8, 2024
1 parent 3b3f0bf commit 8e6b985
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/chainconsumer/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _get_2d_latex_table(self, named_matrix: Named2DMatrix, caption: str, label:
for p, row in zip(parameters, matrix):
table += format_string % p
for r in row:
table += " & %5.2f" % r
table += f" & {r:5.2f}"
table += " \\\\ \n"
table += hline_text
return latex_table % (column_def, table)
Expand Down Expand Up @@ -396,7 +396,7 @@ def get_parameter_text(self, bound: Bound, wrap: bool = False):
if factor != 0:
text = r"\left( %s \right) \times 10^{%d}" % (text, -factor)
if wrap:
text = "$%s$" % text
text = f"${text}$"
return text

def get_parameter_summary_mean(self, chain: Chain, column: ColumnName) -> Bound | None:
Expand Down
2 changes: 1 addition & 1 deletion src/chainconsumer/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def dic(self) -> dict[str, float]:
for name, chain in self.parent._chains.items():
p = chain.log_posterior
if p is None:
logger.warning("You need to set the posterior for chain %s to get the DIC" % chain.name)
logger.warning(f"You need to set the posterior for chain {chain.name} to get the DIC")
else:
means = np.array([np.average(chain.samples[c], weights=chain.weights) for c in chain.data_columns])
d = -2 * p
Expand Down
4 changes: 2 additions & 2 deletions src/chainconsumer/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _sanitise_chains(

class Diagnostic:
def __init__(self, parent: "ChainConsumer"):
self.parent: "ChainConsumer" = parent
self.parent: ChainConsumer = parent

def gelman_rubin(
self, chains: list[Chain | ChainName] | Chain | ChainName | None = None, threshold: float = 0.05
Expand Down Expand Up @@ -87,7 +87,7 @@ def gelman_rubin(
r: float = np.sqrt(v / w)

passed = np.abs(r - 1) < threshold
logger.info("Gelman-Rubin Statistic values for chain %s" % name)
logger.info(f"Gelman-Rubin Statistic values for chain {name}")
for p, v, pas in zip(parameters, r, passed):
param = "Param %d" % p if isinstance(p, int) else p
logger.info(f"{param}: {v:7.5f} ({'Passed' if pas else 'Failed'})")
Expand Down
4 changes: 2 additions & 2 deletions src/chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_artists_from_chains(chains: list[Chain]) -> list[Artist]:

class Plotter:
def __init__(self, parent: "ChainConsumer") -> None:
self.parent: "ChainConsumer" = parent
self.parent: ChainConsumer = parent
self._config: PlotConfig | None = None
self._default_config = PlotConfig()

Expand Down Expand Up @@ -983,7 +983,7 @@ def _plot_bars(
r"${} = {}$".format(label.strip("$"), t), fontsize=self.config.summary_font_size
)
else:
ax.set_title(r"$%s$" % t, fontsize=self.config.summary_font_size)
ax.set_title(rf"${t}$", fontsize=self.config.summary_font_size)
return float(ys.max())

def _plot_walk(
Expand Down

0 comments on commit 8e6b985

Please sign in to comment.