Merge pull request #51 from datamol-io/chore-maintain
Chores: doc update
maclandrol authored Jul 25, 2024
2 parents d4610e7 + 39349c4 commit 9413f0a
Showing 1 changed file with 23 additions and 5 deletions.
safe/sample.py: 28 changes (23 additions & 5 deletions)
@@ -35,6 +35,23 @@ def __init__(
):
"""SAFEDesign constructor
+    !!! info
+        Design methods in SAFE are not deterministic at the token sampling step.
+        If a method accepts a `random_seed`, the seed applies to the SAFE-related
+        algorithms, not to the sampling from the autoregressive model. To get
+        deterministic sampling, set the seed at the `transformers` package level.
+        ```python
+        import safe as sf
+        import transformers
+        my_seed = 100
+        designer = sf.SAFEDesign(...)
+        transformers.set_seed(my_seed)  # use this before calling a design function
+        designer.linker_generation(...)
+        ```
Args:
model: input SAFEDoubleHeadsModel to use for generation
tokenizer: input SAFETokenizer to use for generation
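For context on the new docstring note: `transformers.set_seed` is the library's catch-all seeding helper. A rough sketch of why one call at the `transformers` level is enough (the real helper also covers TensorFlow and a `deterministic` flag, both omitted here):

```python
import random

import numpy as np
import torch
import transformers

transformers.set_seed(100)

# Roughly what set_seed(100) does under the hood, so a single call
# covers the Python, NumPy, and PyTorch RNGs at once:
random.seed(100)
np.random.seed(100)
torch.manual_seed(100)
torch.cuda.manual_seed_all(100)  # safe no-op when CUDA is unavailable
```

Re-issuing `set_seed` immediately before each design call is what makes two runs of the same method reproducible.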
@@ -269,7 +286,6 @@ def _fragment_linking(
self.safe_encoder.slicer = old_slicer

fragments = encoded_fragment.split(".")

missing_closure = Counter(self.safe_encoder._find_branch_number(encoded_fragment))
missing_closure = [f"{str(x)}" for x in missing_closure if missing_closure[x] % 2 == 1]
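The `Counter` logic above keeps only the ring-closure numbers that occur an odd number of times in the encoded fragment, i.e. the attachment points still left open. A toy illustration of that filter, with a plain list standing in for the private `_find_branch_number` helper:

```python
from collections import Counter

# Hypothetical stand-in for `_find_branch_number(encoded_fragment)`:
# ring-closure numbers found in the string. "1" is opened and closed
# (paired), while "2" appears once and is therefore unpaired.
branch_numbers = ["1", "1", "2"]

counts = Counter(branch_numbers)
missing_closure = [str(x) for x in counts if counts[x] % 2 == 1]
print(missing_closure)  # ['2'] -> the open point where linking can happen
```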

@@ -401,7 +417,7 @@ def motif_extension(
n_samples_per_trial: int = 10,
n_trials: Optional[int] = 1,
sanitize: bool = False,
-    do_not_fragment_further: Optional[bool] = False,
+    do_not_fragment_further: Optional[bool] = True,
random_seed: Optional[int] = None,
**kwargs,
):
@@ -434,7 +450,7 @@ def super_structure(
n_samples_per_trial: int = 10,
n_trials: Optional[int] = 1,
sanitize: bool = False,
-    do_not_fragment_further: Optional[bool] = False,
+    do_not_fragment_further: Optional[bool] = True,
random_seed: Optional[int] = None,
attachment_point_depth: Optional[int] = None,
**kwargs,
@@ -499,7 +515,7 @@ def scaffold_decoration(
scaffold: Union[str, dm.Mol],
n_samples_per_trial: int = 10,
n_trials: Optional[int] = 1,
-    do_not_fragment_further: Optional[bool] = False,
+    do_not_fragment_further: Optional[bool] = True,
sanitize: bool = False,
random_seed: Optional[int] = None,
add_dot=True,
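These three hunks flip the default of `do_not_fragment_further` from `False` to `True` in `motif_extension`, `super_structure`, and `scaffold_decoration`, so the input structure is now kept whole by default instead of being sliced into smaller fragments first. A usage sketch under stated assumptions (constructor arguments elided as in the docstring, scaffold SMILES invented for illustration):

```python
import safe as sf

designer = sf.SAFEDesign(...)  # model/tokenizer elided, as in the docstring

# With the new default the scaffold is used as-is; pass
# do_not_fragment_further=False to recover the previous slicing behavior.
decorated = designer.scaffold_decoration(
    scaffold="[1*]N1CCN(CC1)[2*]",  # hypothetical scaffold with attachment points
    n_samples_per_trial=10,
    sanitize=True,
)
```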
@@ -855,14 +871,16 @@ def _generate(
else:
# EN: we remove the EOS token added before running the prediction
# because the model outputs nonsense when we keep it.
# I don't know why it works for text generation but not here
for k in input_ids:
input_ids[k] = input_ids[k][:, :-1]

for k, v in input_ids.items():
if torch.is_tensor(v):
input_ids[k] = v.to(self.model.device)

+        # we remove the token_type_ids to support more model types than just GPT2
+        input_ids.pop("token_type_ids", None)

if is_greedy:
kwargs["num_return_sequences"] = 1
if num_beams is not None and num_beams > 1:
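The hunk above normalizes the encoded batch just before generation: every tensor loses its trailing EOS column, tensors are moved to the model's device, and `token_type_ids` is dropped because not every architecture accepts that key. A self-contained toy version of the same cleanup (tensor values invented for illustration):

```python
import torch

# Toy batch shaped like a tokenizer's output; 2 stands in for the EOS id.
encoding = {
    "input_ids": torch.tensor([[5, 8, 2]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "token_type_ids": torch.tensor([[0, 0, 0]]),
}

for k in encoding:
    encoding[k] = encoding[k][:, :-1]  # strip the trailing EOS column

encoding.pop("token_type_ids", None)  # widens support beyond GPT2-style models
print({k: v.tolist() for k, v in encoding.items()})
# {'input_ids': [[5, 8]], 'attention_mask': [[1, 1]]}
```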
