Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into feat/change-installation
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored May 29, 2021
2 parents 7079884 + 9c5dbd7 commit bc98837
Show file tree
Hide file tree
Showing 28 changed files with 306 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Want to add a new test case and not sure how? [Talk to us!](https://join.slack.c

## Guidelines

For this section, we refer to read the [parent PL guidelines](https://pytorch-lightning.readthedocs.io/en/latest/CONTRIBUTING.html)
For this section, we refer to read the [parent PL guidelines](https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html)

**Reminder**

Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346))


- Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


- Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


## [0.3.0] - 2021-05-20

### Added
Expand Down
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def _load_py_module(fname, pkg="flash"):
#
source_suffix = [".rst", ".md"]

needs_sphinx = '3.4'
needs_sphinx = "4.0"

# -- Options for intersphinx extension ---------------------------------------

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
Expand Down
2 changes: 1 addition & 1 deletion flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
):

if not _IMAGE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install . '[image]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

self.save_hyperparameters()

Expand Down
2 changes: 1 addition & 1 deletion flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
pooling_fn: Callable = torch.max
):
if not _IMAGE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[image]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

super().__init__(
model=None,
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SemanticSegmentationPathsDataSource(PathsDataSource):

def __init__(self):
if not _IMAGE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[image]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")
super().__init__(IMG_EXTENSIONS)

def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]],
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
image_size: A tuple with the expected output image size.
"""
if not _IMAGE_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[image]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")
self.image_size = image_size
self.num_classes = num_classes
if num_classes:
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
) -> None:

if isinstance(backbone, str) and (not _TORCHVISION_AVAILABLE or not _TIMM_AVAILABLE):
raise ModuleNotFoundError("Please, pip install -e '.[image]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

if metrics is None:
metrics = IoU(num_classes=num_classes)
Expand Down
2 changes: 1 addition & 1 deletion flash/image/style_transfer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
):

if not _IMAGE_STLYE_TRANSFER:
raise ModuleNotFoundError("Please, pip install -e '.[image_style_transfer]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[image_style_transfer]'")

self.save_hyperparameters(ignore="style_image")

Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
**tabnet_kwargs,
):
if not _TABULAR_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[tabular]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[tabular]'")

self.save_hyperparameters()

Expand Down
50 changes: 48 additions & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(self, backbone: str, max_length: int = 128):
super().__init__()

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length

Expand All @@ -55,6 +56,15 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str,
ex[target] = label_to_class_mapping[ex[target]]
return ex

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextFileDataSource(TextDataSource):

Expand Down Expand Up @@ -115,18 +125,45 @@ def load_data(
def predict_load_data(self, data: Any, dataset: AutoDataset):
return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextCSVDataSource(TextFileDataSource):

def __init__(self, backbone: str, max_length: int = 128):
super().__init__("csv", backbone, max_length=max_length)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextJSONDataSource(TextFileDataSource):

def __init__(self, backbone: str, max_length: int = 128):
super().__init__("json", backbone, max_length=max_length)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):

Expand All @@ -143,6 +180,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s, ) for s in data]

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextClassificationPreprocess(Preprocess):

Expand All @@ -157,7 +203,7 @@ def __init__(
):

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

self.backbone = backbone
self.max_length = max_length
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

self.save_hyperparameters()

Expand Down
94 changes: 90 additions & 4 deletions flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import flash
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _TEXT_AVAILABLE

Expand All @@ -43,9 +43,10 @@ def __init__(
super().__init__()

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.padding = padding
Expand All @@ -71,6 +72,15 @@ def _tokenize_fn(
padding=self.padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqFileDataSource(Seq2SeqDataSource):

Expand Down Expand Up @@ -112,6 +122,15 @@ def load_data(
def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]:
return self.load_data(data, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqCSVDataSource(Seq2SeqFileDataSource):

Expand All @@ -130,6 +149,15 @@ def __init__(
padding=padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqJSONDataSource(Seq2SeqFileDataSource):

Expand All @@ -148,6 +176,15 @@ def __init__(
padding=padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqSentencesDataSource(Seq2SeqDataSource):

Expand All @@ -161,6 +198,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s) for s in data]

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


@dataclass(unsafe_hash=True, frozen=True)
class Seq2SeqBackboneState(ProcessState):
Expand Down Expand Up @@ -190,7 +236,7 @@ def __init__(
self.padding = padding

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

super().__init__(
train_transform=train_transform,
Expand Down Expand Up @@ -240,7 +286,47 @@ def collate(self, samples: Any) -> Tensor:
return default_data_collator(samples)


class Seq2SeqPostprocess(Postprocess):

def __init__(self):
super().__init__()

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

self._backbone = None
self._tokenizer = None

@property
def backbone(self):
backbone_state = self.get_state(Seq2SeqBackboneState)
if backbone_state is not None:
return backbone_state.backbone

@property
def tokenizer(self):
if self.backbone is not None and self.backbone != self._backbone:
self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self._backbone = self.backbone
return self._tokenizer

def uncollate(self, generated_tokens: Any) -> Any:
pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
pred_str = [str.strip(s) for s in pred_str]
return pred_str

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("_tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqData(DataModule):
"""Data module for Seq2Seq tasks."""

preprocess_cls = Seq2SeqPreprocess
postprocess_cls = Seq2SeqPostprocess
2 changes: 1 addition & 1 deletion flash/text/seq2seq/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
num_beams: Optional[int] = None,
):
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")
raise ModuleNotFoundError("Please, pip install 'lightning-flash[text]'")

os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
Expand Down
Loading

0 comments on commit bc98837

Please sign in to comment.