Play LEGO with papers.
fastpapers is a Python library where I use fastai to reproduce papers in Jupyter notebooks. I use nbdev to turn these notebooks into modules.
pip install fastpapers
Download the COCO data
path = download_coco(force_download=False)
Create the DataLoaders and a DETR Learner, then fit for one epoch.
dls = CocoDataLoaders.from_sources(path, vocab=coco_vocab, num_workers=0)
learnd = detr_learner(dls)
learnd.fit(1, lr=[1e-5, 1e-5, 1e-5])
epoch | train_loss | valid_loss | AP | AP50 | AP75 | AP_small | AP_medium | AP_large | AR1 | AR10 | AR100 | AR_small | AR_medium | AR_large | time |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 5.892842 | 7.636298 | 0.475381 | 0.574125 | 0.506063 | 0.297741 | 0.458006 | 0.560994 | 0.355018 | 0.545646 | 0.560374 | 0.375141 | 0.541728 | 0.630330 | 2:05:24 |
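The columns in this table are COCO-style detection metrics. AP averages precision over IoU thresholds from 0.5 to 0.95, while AP50 and AP75 use a single threshold of 0.5 and 0.75. AR1, AR10 and AR100 are recall with at most 1, 10 and 100 detections per image. The minimal sketch below shows the IoU test that decides whether a predicted box matches a ground-truth box. `box_iou` is a hypothetical helper for illustration, not part of fastpapers, and boxes are assumed to be `(x_min, y_min, x_max, y_max)`.

```python
def box_iou(a, b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    # Overlap rectangle; width/height clamp to 0 when the boxes do not intersect
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

pred, gt = (0, 0, 10, 10), (2, 0, 12, 10)
iou = box_iou(pred, gt)          # 80 / 120, about 0.667
print(iou >= 0.5, iou >= 0.75)   # True False: a hit for AP50 but not for AP75
```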
Show the results
with learnd.removed_cbs(learnd.coco_eval):
    learnd.show_results(max_n=8, figsize=(10,10))
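Here `removed_cbs` acts as a context manager. It takes the `coco_eval` callback out while `show_results` runs and puts it back afterwards, presumably so that COCO evaluation does not run while plotting. The sketch below shows that pattern with `contextlib`. `removed_items` is a hypothetical stand-in written for illustration and does not reproduce fastai's implementation.

```python
from contextlib import contextmanager

@contextmanager
def removed_items(items, *to_remove):
    """Temporarily remove entries from a list, restoring them on exit (even on error)."""
    removed = [x for x in to_remove if x in items]
    for x in removed:
        items.remove(x)
    try:
        yield items
    finally:
        items.extend(removed)  # restored at the end, not at their original positions

cbs = ["recorder", "coco_eval", "progress"]
with removed_items(cbs, "coco_eval"):
    print(cbs)   # ['recorder', 'progress']
print(cbs)       # ['recorder', 'progress', 'coco_eval']
```

The `try`/`finally` matters: if plotting raises, the callback still comes back, so later training or validation runs are unaffected.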
Download the Imagenette data
path = untar_data(URLs.IMAGENETTE)
Create the DataLoaders and a super-resolution Learner, then fit.
# Input block at 72 px, target block at 288 px
db = DataBlock(blocks=(ResImageBlock(72), ResImageBlock(288)),
               get_items=get_image_files,
               # Map pixel values from [0, 1] to [-1, 1]
               batch_tfms=Normalize.from_stats([0.5]*3, [0.5]*3))
dls = db.dataloaders(path, bs=4, num_workers=4)
learn = superres_learner(dls)
learn.fit(16, lr=1e-3, wd=0)
learn.show_results()
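This DataBlock pairs a 72 px input with a 288 px target, so the model learns a 4× upscale (288 / 72 = 4). That reading assumes `ResImageBlock(n)` resizes images to `n` pixels. `Normalize.from_stats([0.5]*3, [0.5]*3)` computes `(x - 0.5) / 0.5` per channel, which maps pixel values in [0, 1] to [-1, 1]. The arithmetic and its inverse, used when displaying outputs, can be checked in plain Python. `normalize` and `denormalize` are hypothetical helpers.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization as computed by Normalize.from_stats(mean, std)."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, e.g. to turn model outputs back into displayable pixels."""
    return y * std + mean

scale = 288 // 72
print(scale)                                     # 4
print([normalize(v) for v in (0.0, 0.5, 1.0)])   # [-1.0, 0.0, 1.0]
print(denormalize(normalize(0.25)))              # 0.25
```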
The name of each module is the BibTeX key of the corresponding paper. For example, to use the FID metric from Heusel, Martin, et al. 2017, import it like so:
from fastpapers.heusel2017gans import FIDMetric
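For reference, FID compares the mean and covariance of Inception features extracted from real and generated images:

FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)).

The NumPy sketch below evaluates this formula on precomputed statistics. `fid_from_stats` is hypothetical and is not the `FIDMetric` API, which also has to extract the features. The sketch relies on one fact about covariance matrices: the trace of (Sigma_r Sigma_g)^(1/2) equals the sum of the square roots of the eigenvalues of Sigma_r Sigma_g, because those eigenvalues are real and non-negative.

```python
import numpy as np

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Eigenvalues of sigma1 @ sigma2 are real and >= 0 for covariance matrices;
    # clip tiny negative/imaginary parts caused by floating-point error.
    eig = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eig.real, 0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
fake = rng.normal(loc=0.5, size=(500, 8))
stats = lambda f: (f.mean(axis=0), np.cov(f, rowvar=False))
print(fid_from_stats(*stats(real), *stats(real)))  # close to 0 for identical statistics
print(fid_from_stats(*stats(real), *stats(fake)))  # larger: the means differ by about 0.5 per dimension
```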
To train a pix2pix model from Isola, Phillip, et al. 2017, import pix2pix_learner:
from fastpapers.isola2017image import pix2pix_learner
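Both module names look like Google Scholar-style BibTeX keys. Each is the first author's surname, the year, and the first word of the title, all lowercase:

- heusel2017gans comes from "GANs trained by a two time-scale update rule converge to a local Nash equilibrium".
- isola2017image comes from "Image-to-image translation with conditional adversarial networks".

The sketch below encodes the rule inferred from these two examples. `bibtex_key` is a hypothetical helper, not part of fastpapers. Skipping leading articles is an extra assumption, and individual modules may not follow the pattern exactly.

```python
import re

# Assumption: leading articles are skipped, as Google Scholar-style keys appear to do.
STOPWORDS = {"a", "an", "the"}

def bibtex_key(first_author_surname: str, year: int, title: str) -> str:
    """Hypothetical helper: build a key like 'heusel2017gans' from paper metadata."""
    words = [w.lower() for w in re.findall(r"[A-Za-z0-9]+", title)]
    first = next((w for w in words if w not in STOPWORDS), "")
    return f"{first_author_surname.lower()}{year}{first}"

print(bibtex_key("Heusel", 2017, "GANs trained by a two time-scale update rule "
                 "converge to a local Nash equilibrium"))   # heusel2017gans
print(bibtex_key("Isola", 2017, "Image-to-image translation with "
                 "conditional adversarial networks"))       # isola2017image
```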
The core module contains functions and classes that are useful for several papers.
For example, ImageNTuple lets you work with an arbitrary number of images as input.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
it = ImageNTuple.create((files[0], files[1], files[2]))
it = Resize(224)(it)
it = ToTensor()(it)
it.show();
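The idea behind `ImageNTuple` is that one object holds any number of images, and a transform such as `Resize(224)` applies to each element. The pure-Python sketch below illustrates that behavior only. `NTuple`, its `map` method and `fake_resize` are hypothetical names, and the sketch does not reproduce fastpapers' implementation.

```python
class NTuple(tuple):
    """Hypothetical tuple of N items where a transform is applied element-wise."""

    @classmethod
    def create(cls, items):
        return cls(items)

    def map(self, fn):
        # Apply the same transform to every item and keep the NTuple type
        return type(self)(fn(x) for x in self)

def fake_resize(size):
    """Stand-in for a resize transform: returns a copy of an 'image' dict with a new size."""
    return lambda img: {**img, "size": (size, size)}

imgs = NTuple.create([{"name": f"img{i}", "size": (500, 375)} for i in range(3)])
resized = imgs.map(fake_resize(224))
print(len(resized), type(resized).__name__)   # 3 NTuple
print([img["size"] for img in resized])       # [(224, 224), (224, 224), (224, 224)]
```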
It also has functions that are useful for debugging, such as explode_shapes or explode_ranges:
explode_shapes(it)
[(3, 224, 224), (3, 224, 224), (3, 224, 224)]
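The output suggests that `explode_shapes` walks a possibly nested collection and returns the shape of each tensor. The NumPy sketch below illustrates that idea. It also includes a companion that returns `(min, max)` per element, in the spirit of `explode_ranges`. `shapes_of` and `ranges_of` are hypothetical, and the real functions may handle nesting and return types differently.

```python
import numpy as np

def shapes_of(x):
    """Recursively collect the shape of every array-like leaf in nested lists/tuples."""
    if hasattr(x, "shape"):
        return tuple(x.shape)
    if isinstance(x, (list, tuple)):
        return [shapes_of(o) for o in x]
    return None  # leaves without a shape (e.g. plain ints) are reported as None

def ranges_of(x):
    """Recursively collect (min, max) of every array-like leaf."""
    if hasattr(x, "shape"):
        return (float(x.min()), float(x.max()))
    if isinstance(x, (list, tuple)):
        return [ranges_of(o) for o in x]
    return None

batch = (np.zeros((3, 224, 224)), np.ones((3, 224, 224)),
         np.linspace(-1, 1, 12).reshape(3, 2, 2))
print(shapes_of(batch))  # [(3, 224, 224), (3, 224, 224), (3, 2, 2)]
print(ranges_of(batch))  # [(0.0, 0.0), (1.0, 1.0), (-1.0, 1.0)]
```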