Skip to content

Commit

Permalink
renaming batch params to be more specific
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 22, 2023
1 parent 53649e3 commit b6c70b4
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions linear_relational/training/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ def train_lre(
prompts: list[Prompt],
object_aggregation: ObjectAggregation = "mean",
validate_prompts: bool = True,
batch_size: int = 8,
validate_prompts_batch_size: int = 4,
move_to_cpu: bool = False,
verbose: bool = True,
) -> Lre:
processed_prompts = self._process_relation_prompts(
relation=relation,
prompts=prompts,
validate_prompts=validate_prompts,
batch_size=batch_size,
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
)
return train_lre(
Expand All @@ -85,7 +85,7 @@ def train_relation_concepts(
object_aggregation: ObjectAggregation = "mean",
vector_aggregation: VectorAggregation = "post_mean",
inv_lre_rank: int = 200,
batch_size: int = 8,
validate_prompts_batch_size: int = 4,
validate_prompts: bool = True,
verbose: bool = True,
name_concept_fn: Optional[Callable[[str, str], str]] = None,
Expand All @@ -95,7 +95,7 @@ def train_relation_concepts(
relation=relation,
prompts=prompts,
validate_prompts=validate_prompts,
batch_size=batch_size,
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
)
prompts_by_object = group_items(processed_prompts, lambda p: p.object_name)
Expand All @@ -115,15 +115,15 @@ def train_relation_concepts(
prompts=lre_train_prompts,
object_aggregation=object_aggregation,
validate_prompts=False, # we already validated the prompts above
batch_size=batch_size,
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
).invert(inv_lre_rank)

return self.train_relation_concepts_from_inv_lre(
inv_lre=inv_lre,
prompts=processed_prompts,
vector_aggregation=vector_aggregation,
batch_size=batch_size,
validate_prompts_batch_size=validate_prompts_batch_size,
validate_prompts=False, # we already validated the prompts above
name_concept_fn=name_concept_fn,
verbose=verbose,
Expand All @@ -134,7 +134,8 @@ def train_relation_concepts_from_inv_lre(
inv_lre: InvertedLre,
prompts: list[Prompt],
vector_aggregation: VectorAggregation = "post_mean",
batch_size: int = 8,
validate_prompts_batch_size: int = 4,
extract_objects_batch_size: int = 4,
validate_prompts: bool = True,
name_concept_fn: Optional[Callable[[str, str], str]] = None,
verbose: bool = True,
Expand All @@ -144,13 +145,13 @@ def train_relation_concepts_from_inv_lre(
relation=relation,
prompts=prompts,
validate_prompts=validate_prompts,
batch_size=batch_size,
validate_prompts_batch_size=validate_prompts_batch_size,
verbose=verbose,
)
start_time = time()
object_activations = self._extract_target_object_activations_for_inv_lre(
prompts=processed_prompts,
batch_size=batch_size,
batch_size=extract_objects_batch_size,
object_aggregation=inv_lre.object_aggregation,
object_layer=inv_lre.object_layer,
show_progress=verbose,
Expand Down Expand Up @@ -186,14 +187,14 @@ def _process_relation_prompts(
relation: str,
prompts: list[Prompt],
validate_prompts: bool,
batch_size: int,
validate_prompts_batch_size: int,
verbose: bool,
) -> list[Prompt]:
valid_prompts = prompts
if validate_prompts:
log_or_print(f"validating {len(prompts)} prompts", verbose=verbose)
valid_prompts = self.prompt_validator.filter_prompts(
prompts, batch_size, verbose
prompts, validate_prompts_batch_size, verbose
)
if len(valid_prompts) == 0:
raise ValueError(f"No valid prompts found for {relation}.")
Expand Down

0 comments on commit b6c70b4

Please sign in to comment.