From 9d63edfc7e2644d4435d3f38c73cdd78025c84ff Mon Sep 17 00:00:00 2001 From: tsugumi-sys Date: Sun, 7 Jan 2024 22:03:39 +0900 Subject: [PATCH] fix arg-type error --- pipelines/moving_mnist_pipeline/data_loader.py | 4 ++-- pyproject.toml | 8 +------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/pipelines/moving_mnist_pipeline/data_loader.py b/pipelines/moving_mnist_pipeline/data_loader.py index 01d10b3..227e065 100644 --- a/pipelines/moving_mnist_pipeline/data_loader.py +++ b/pipelines/moving_mnist_pipeline/data_loader.py @@ -1,12 +1,12 @@ from typing import Tuple import torch -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader, Dataset, Subset, random_split from torchvision.datasets import MovingMNIST class VideoPredictionDataset(Dataset): - def __init__(self, data: torch.Tensor, input_frames: int = 10): + def __init__(self, data: Subset, input_frames: int = 10): self.data = data self.input_frames = input_frames diff --git a/pyproject.toml b/pyproject.toml index 75b2e5c..eadafcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,10 +30,4 @@ split-on-trailing-comma = true namespace_packages = true ignore_missing_imports = true python_version = "3.11" -disable_error_code = [ - "arg-type", - "assignment", - "override", - "index", - "var-annotated", -] +disable_error_code = ["assignment", "override", "index", "var-annotated"]