diff --git a/resolve.py b/resolve.py index b760e4e..c1fc908 100644 --- a/resolve.py +++ b/resolve.py @@ -58,7 +58,7 @@ def super_resolve(img: str, halve: bool = False): sr_img_srresnet = tf.squeeze(sr_img_srresnet['output_0']) sr_img_srresnet = transform.convert_image(sr_img_srresnet, source='[-1, 1]', target='pil') - sr_img_srgan = resnet_inference(lr_img) + sr_img_srgan = generator_inference(lr_img) sr_img_srgan = tf.squeeze(sr_img_srgan['output_0']) sr_img_srgan = transform.convert_image(sr_img_srgan, source='[-1, 1]', target='pil') diff --git a/train.py b/train.py index d02f8e3..c7c7576 100644 --- a/train.py +++ b/train.py @@ -19,7 +19,7 @@ large_kernel_size = 9 # kernel size of the first and last convolutions which transform the inputs and outputs small_kernel_size = 3 # kernel size of the first and last convolutions which transform the inputs and outputs n_channels = 64 # number of channels in-between, input and output channels for residual & subpixel conv blocks -n_blocks = 16 # number of residual blocks +n_blocks = 32 # number of residual blocks srresnet_checkpoint = "SuperResolutionResNet_9999" # trained SRResNet checkpoint used for initialization # Discriminator parameters @@ -94,4 +94,4 @@ def main(architecture_type: str = "resnet"): if __name__ == "__main__": - main(architecture_type="gan") + main(architecture_type="resnet")