Get output images #3

Open · wants to merge 2 commits into base: master
README.md: 7 changes (6 additions & 1 deletion)

@@ -50,6 +50,11 @@ If you use the pre-trained weights of the model, you can
cd ./src
python main.py --skip_training --RESUME=True --gpu_id=0
```
To test on a folder of input images and masks and save the output images:
```bash
cd ./src
python main.py --skip_training --RESUME=True --gpu_id=0 --test --image ../examples/places2 --mask ../examples/places2_masks --output ../examples/places2_results
```
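The three folders are matched by file name: for each input image in `--image`, the test code added to `src/main.py` (shown below) looks for a PNG mask with the same stem in `--mask` and writes a JPEG result to `--output`; judging from the compositing step, non-zero mask pixels mark the region to be inpainted. The example files added in this PR follow that layout:
```bash
examples/places2/building.png          # input image
examples/places2_masks/building.png    # mask with the same stem (PNG; non-zero = region to fill)
examples/places2_results/building.jpg  # result written to the --output folder
```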


## Some Details
@@ -68,4 +73,4 @@ If you use it for your research, please cite our paper [Progressive Image Inpaint
year={2019},
organization={ACM}
}
```
Binary file added examples/places2/building.png
Binary file added examples/places2/canyon.png
Binary file added examples/places2/grass.png
Binary file added examples/places2/sunset.png
Binary file added examples/places2/wooden.png
Binary file added examples/places2_masks/building.png
Binary file added examples/places2_masks/canyon.png
Binary file added examples/places2_masks/grass.png
Binary file added examples/places2_masks/sunset.png
Binary file added examples/places2_masks/wooden.png
Binary file added examples/places2_results/building.jpg
Binary file added examples/places2_results/canyon.jpg
Binary file added examples/places2_results/grass.jpg
Binary file added examples/places2_results/sunset.jpg
Binary file added examples/places2_results/wooden.jpg
src/main.py: 85 changes (73 additions & 12 deletions)

@@ -12,6 +12,11 @@
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr

import pdb
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F

parser = argparse.ArgumentParser(description='Image Inpainting')
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--max_iterations', type=int, default=500000, help="max iteration number in one epoch")
@@ -23,7 +28,7 @@
parser.add_argument('--TEST_FLIST', type=str, default='./flist/places2_val.flist')
parser.add_argument('--TEST_MASK_FLIST', type=str, default='./flist/masks_30to40.flist')

parser.add_argument('--save_model_dir', type=str, default='./save_models')
parser.add_argument('--save_model_dir', type=str, default='../save_models')
parser.add_argument('--save_iter_interval', type=int, default=100000, help="interval for saving model")

parser.add_argument('--LR', type=float, default=0.0002, help='learning rate (default: 0.0002)')
@@ -38,6 +43,10 @@
parser.add_argument('--RESUME', default=False, type=bool, help='load pre-trained weights')
parser.add_argument('--skip_training', default=False, action='store_true')
parser.add_argument('--skip_validation', default=False, action='store_true')
parser.add_argument('--image', type=str, help='path to input images')
parser.add_argument('--mask', type=str, help='path to input png hole masks')
parser.add_argument('--output', type=str, help='path to save jpg results')
parser.add_argument('--test', default=False, action='store_true')

config = parser.parse_args()

@@ -52,12 +61,13 @@

inpaint_model = InpaintingModel(config).to(DEVICE)

if not config.test:
    if not config.skip_training:
        train_set = Dataset(config, config.TRAIN_FLIST, config.TRAIN_MASK_FLIST, augment=True, training=True)
        train_loader = DataLoader(dataset=train_set, batch_size=config.batch_size, num_workers=config.num_workers, drop_last=True, shuffle=True)
    if not config.skip_validation:
        eval_set = Dataset(config, config.TEST_FLIST, config.TEST_MASK_FLIST, augment=True, training=True)
        eval_loader = DataLoader(dataset=eval_set, batch_size=1, shuffle=True)


def train():
@@ -97,15 +107,66 @@ def eval():
        print('[EVAL] ({}/{}) PSNR:{:.4f}, SSIM:{:.4f}, L1:{:.4f}'.
              format(eval_iter, len(eval_loader), log[0] / eval_iter, log[1] / eval_iter, log[2] / eval_iter))

def test():
    print("============================= TEST ============================")
    inpaint_model.discriminator.eval()
    inpaint_model.generator.eval()

    if not os.path.exists(config.output):
        os.mkdir(config.output)

    # First pass: run the generator on every image/mask pair and save the raw outputs.
    for name in tqdm(os.listdir(config.image)):
        if name.startswith("."):
            continue
        img = Image.open(os.path.join(config.image, name)).convert("RGB")
        name = ".".join(name.split(".")[:-1])
        mask = Image.open(os.path.join(config.mask, name + ".png")).convert("L")
        # Round the spatial size up to the next multiple of 16 before the forward pass.
        w, h = img.size
        _w = (w // 16 + 1) * 16
        _h = (h // 16 + 1) * 16
        img = img.resize((_w, _h))
        mask = mask.resize((_w, _h))
        # HWC uint8 -> NCHW float tensors in [0, 1]; the mask becomes binary,
        # and the model is fed the inverted mask (1 - mask).
        img = np.array(img).transpose((2, 0, 1))[None, ...]
        mask = np.array(mask)[None, None, ...]
        img = torch.Tensor(img).float().cuda() / 255.0
        mask = torch.Tensor(mask > 0).float().cuda()
        outputs, gen_loss, dis_loss = inpaint_model.process(img, 1 - mask)
        # Back to an HWC uint8 image at the original resolution.
        img_out = outputs[0].detach().permute(1, 2, 0).cpu().numpy() * 255.0
        img_out[img_out > 255.0] = 255.0
        img_out = Image.fromarray(img_out.astype(np.uint8))
        img_out = img_out.resize((w, h))
        img_out.save(os.path.join(config.output, name + ".jpg"))

    # Second pass: composite each saved result with its input so that only the
    # masked (non-zero) region comes from the network output.
    for name in os.listdir(config.image):
        if name.startswith("."):
            # skip hidden files (e.g. .DS_Store), as in the first pass
            continue
        image = Image.open(os.path.join(config.image, name)).convert("RGB")
        name = ".".join(name.split("/")[-1].split(".")[:-1])
        mask = Image.open(os.path.join(config.mask, name + ".png")).convert("L")
        mask = np.array(mask)[..., None]
        W, H = image.size
        image = np.array(image)
        result = Image.open(os.path.join(config.output, name + ".jpg"))
        result = result.resize((W, H))
        result = np.array(result)
        result = result * (mask > 0) + image * (mask == 0)
        result = Image.fromarray(result.astype(np.uint8))
        result.save(os.path.join(config.output, name + ".jpg"))


if config.RESUME:
    inpaint_model.load()

if config.test:
    with torch.no_grad():
        test()
else:
    if not config.skip_training:
        train()

    if not config.skip_validation:
        with torch.no_grad():
            eval()