Merge pull request advimman#112 from geomagical/refinement
Feature Refinement to Improve High Resolution Image Inpainting
senya-ashukha authored Jul 28, 2022
2 parents 3852b47 + fa9725e commit b1ff47f
Showing 4 changed files with 362 additions and 24 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -153,6 +153,12 @@ bash docker/2_predict.sh $(pwd)/big-lama $(pwd)/LaMa_test_images $(pwd)/output d
 ```
 Docker cuda: TODO
+
+**4. Predict with Refinement**
+
+On the host machine:
+
+python3 bin/predict.py refine=True model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output
 
 # Train and Eval
 ⚠️ Warning: The training is not fully tested yet, e.g., did not re-training after refactoring ⚠️
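The `key=value` pairs in the command above are Hydra/OmegaConf dot-list overrides applied on top of configs/prediction/default.yaml (extended later in this diff). As a rough illustration only, not code from this PR, the same merge can be reproduced with plain OmegaConf; the paths are placeholders, and predict.py itself resolves the config through Hydra:

```python
# Illustrative only (not part of this PR): how the CLI overrides map onto the
# prediction config. Run from the repo root so the config file path resolves.
from omegaconf import OmegaConf

base = OmegaConf.load('configs/prediction/default.yaml')   # refine: False by default
overrides = OmegaConf.from_dotlist([
    'refine=True',
    'model.path=/path/to/big-lama',            # placeholder paths
    'indir=/path/to/LaMa_test_images',
    'outdir=/path/to/output',
])
cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg.refiner))          # the refiner block added in this diff
```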
56 changes: 32 additions & 24 deletions bin/predict.py
@@ -12,7 +12,7 @@
 import traceback
 
 from saicinpainting.evaluation.utils import move_to_device
-
+from saicinpainting.evaluation.refinement import refine_predict
 os.environ['OMP_NUM_THREADS'] = '1'
 os.environ['OPENBLAS_NUM_THREADS'] = '1'
 os.environ['MKL_NUM_THREADS'] = '1'
@@ -56,34 +56,42 @@ def main(predict_config: OmegaConf):
                                        predict_config.model.checkpoint)
         model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
         model.freeze()
-        model.to(device)
+        if not predict_config.get('refine', False):
+            model.to(device)
 
         if not predict_config.indir.endswith('/'):
             predict_config.indir += '/'
 
         dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
-        with torch.no_grad():
-            for img_i in tqdm.trange(len(dataset)):
-                mask_fname = dataset.mask_filenames[img_i]
-                cur_out_fname = os.path.join(
-                    predict_config.outdir,
-                    os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
-                )
-                os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
-
-                batch = move_to_device(default_collate([dataset[img_i]]), device)
-                batch['mask'] = (batch['mask'] > 0) * 1
-                batch = model(batch)
-                cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
-
-                unpad_to_size = batch.get('unpad_to_size', None)
-                if unpad_to_size is not None:
-                    orig_height, orig_width = unpad_to_size
-                    cur_res = cur_res[:orig_height, :orig_width]
-
-                cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
-                cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
-                cv2.imwrite(cur_out_fname, cur_res)
+        for img_i in tqdm.trange(len(dataset)):
+            mask_fname = dataset.mask_filenames[img_i]
+            cur_out_fname = os.path.join(
+                predict_config.outdir,
+                os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
+            )
+            os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
+            batch = default_collate([dataset[img_i]])
+            if predict_config.get('refine', False):
+                assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
+                # image unpadding is taken care of in the refiner, so that output image
+                # is same size as the input image
+                cur_res = refine_predict(batch, model, **predict_config.refiner)
+                cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
+            else:
+                with torch.no_grad():
+                    batch = move_to_device(batch, device)
+                    batch['mask'] = (batch['mask'] > 0) * 1
+                    batch = model(batch)
+                    cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
+                    unpad_to_size = batch.get('unpad_to_size', None)
+                    if unpad_to_size is not None:
+                        orig_height, orig_width = unpad_to_size
+                        cur_res = cur_res[:orig_height, :orig_width]
 
+            cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+            cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
+            cv2.imwrite(cur_out_fname, cur_res)
 
     except KeyboardInterrupt:
         LOGGER.warning('Interrupted by user')
     except Exception as ex:
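The new branch above hands the whole batch to `refine_predict` instead of running a single forward pass. The sketch below is only an illustration of the coarse-to-fine feature-refinement idea named in the PR title: split the generator into a feature extractor and a decoder, then at each pyramid level optimize the extracted features with Adam so that the downscaled result agrees, inside the hole, with the already-refined result from the coarser level. The `front`/`rear` split, the L1 loss, and the compositing are assumptions for illustration; the actual implementation lives in saicinpainting/evaluation/refinement.py.

```python
# Minimal, illustrative sketch of coarse-to-fine feature refinement for inpainting.
# NOT the code added by this PR: the real refine_predict splits the LaMa generator,
# handles padding/unpadding, gpu_ids and the px_budget resize. `front`/`rear` are stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F


def refine_sketch(image, mask, front, rear, n_iters=15, lr=0.002, max_scales=3):
    """image: (1, 3, H, W) in [0, 1]; mask: (1, 1, H, W), 1 = hole.

    front/rear: the generator split into a feature extractor and a decoder.
    """
    for p in list(front.parameters()) + list(rear.parameters()):
        p.requires_grad_(False)  # only the intermediate features are optimized

    # Image/mask pyramid, coarsest scale first.
    images, masks = [image], [mask]
    for _ in range(max_scales - 1):
        images.insert(0, F.interpolate(images[0], scale_factor=0.5, mode='bilinear',
                                       align_corners=False))
        masks.insert(0, F.interpolate(masks[0], scale_factor=0.5, mode='nearest'))

    prev_pred, prev_mask = None, None  # refined result from the coarser scale
    for img, msk in zip(images, masks):
        net_in = torch.cat([img * (1 - msk), msk], dim=1)     # masked image + mask
        feats = front(net_in).detach().requires_grad_(True)   # features to optimize

        if prev_pred is not None:  # the coarsest scale is a plain forward pass
            opt = torch.optim.Adam([feats], lr=lr)
            for _ in range(n_iters):
                opt.zero_grad()
                pred = rear(feats)
                # Downscale the prediction and match it, inside the hole,
                # to the already-refined prediction from the coarser scale.
                down = F.interpolate(pred, size=prev_pred.shape[-2:], mode='bilinear',
                                     align_corners=False)
                loss = F.l1_loss(down * prev_mask, prev_pred * prev_mask)
                loss.backward()
                opt.step()

        with torch.no_grad():
            pred = rear(feats)
        prev_pred = img * (1 - msk) + pred * msk  # keep known pixels, fill the hole
        prev_mask = msk

    return prev_pred


if __name__ == '__main__':
    # Tiny stand-in "generator" so the sketch runs end to end.
    front = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.ReLU())
    rear = nn.Conv2d(8, 3, 3, padding=1)
    out = refine_sketch(torch.rand(1, 3, 64, 64),
                        (torch.rand(1, 1, 64, 64) > 0.8).float(), front, rear)
    print(out.shape)  # torch.Size([1, 3, 64, 64])
```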
10 changes: 10 additions & 0 deletions configs/prediction/default.yaml
@@ -12,3 +12,13 @@ dataset:
 
 device: cuda
 out_key: inpainted
+
+refine: False # refiner will only run if this is True
+refiner:
+  gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0,"
+  modulo: ${dataset.pad_out_to_modulo}
+  n_iters: 15 # number of iterations of refinement for each scale
+  lr: 0.002 # learning rate
+  min_side: 512 # all sides of image on all scales should be >= min_side / sqrt(2)
+  max_scales: 3 # max number of downscaling scales for the image-mask pyramid
+  px_budget: 1800000 # pixels budget. Any image will be resized to satisfy height*width <= px_budget
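For orientation, here is a small sketch of how the budget and pyramid settings above interact: `px_budget` caps the working resolution, and `max_scales` together with `min_side` bounds the depth of the image-mask pyramid that is refined coarse to fine. The exact rules live in saicinpainting/evaluation/refinement.py and may differ, so treat this as assumed arithmetic.

```python
# Assumed arithmetic only: roughly how px_budget, min_side and max_scales shape
# the refinement pyramid described by the config comments above.
import math


def plan_refinement(height, width, min_side=512, max_scales=3, px_budget=1_800_000):
    # Resize (keeping aspect ratio) so that height * width <= px_budget.
    if height * width > px_budget:
        ratio = math.sqrt(px_budget / (height * width))
        height, width = int(height * ratio), int(width * ratio)

    # Build a coarse-to-fine pyramid by repeated 2x downscaling, stopping when
    # max_scales is reached or a side would drop below min_side / sqrt(2).
    sizes = [(height, width)]
    while len(sizes) < max_scales:
        h, w = sizes[0][0] // 2, sizes[0][1] // 2
        if min(h, w) < min_side / math.sqrt(2):
            break
        sizes.insert(0, (h, w))
    return sizes  # coarsest first; each level gets n_iters Adam steps at lr


if __name__ == '__main__':
    # e.g. a 4000x3000 photo: resized to fit the pixel budget, then refined
    # over the resulting pyramid, coarsest scale first.
    print(plan_refinement(3000, 4000))
```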