
training hybridnet on multiple datasets with different number of cameras #8

Open
timsainb opened this issue May 4, 2023 · 1 comment


timsainb commented May 4, 2023

Hey Timo,

I'm looking into pretraining a keypoint network again and I'm having trouble figuring out how to follow the instructions you gave via email (shown below).

I found the best way to use a different dataset for pretraining is as follows:

  1. Train a model on the pretraining dataset (in your case the combined rodent annotations). You can just use the train all function for this step.
  2. Retrain the full Hybridnet using the Training Mode 'all' WITHOUT retraining the EfficientNet 2D detector separately. It might be worth experimenting around with the learning rate in this step a little bit (lower ones might work better).
  3. Also train the CenterDetect network, initializing it with the pretrained weights.

This works because the structure of the 3D network does not depend on the number of cameras used, so loading the pretrained HybridNet weights shouldn't be a problem, regardless of the camera configuration.
With this training strategy I got very solid results on a hand-tracking task with only 50 labeled frame-sets, even though it used only 7 cameras (instead of the 12 in the pretraining dataset) and the lighting and background were completely different.

I have three types of datasets:

  • keypoints from a 10 camera rig separate from my own dataset
  • keypoints from my rig, with 6 cameras (including bottom up)
  • keypoints from my rig with 5 cameras (no bottom camera)

I'm trying to do inference on new data from the 5-camera rig, but since I have all of this other labeled data, I want to use it all together for training (at least for the 2D network).

So my understanding is that there are three steps in the Jarvis pipeline (CenterDetect, 2D keypoints, 3D HybridNet).
For your first step (train a model on the pretraining dataset), I first made a training set of all of the datasets combined. I then train a CenterDetect network:

train_interface.train_efficienttrack(
    'CenterDetect',
    project_name,
    num_epochs,
    weights=weights
)

Then I train a 2D keypoint network:

train_interface.train_efficienttrack(
    'KeypointDetect',
    project_name,
    num_epochs,
    weights='latest',
)

Both of these work fine and I get pretty good performance. The issue is when I then try to train the hybridnet:

train_interface.train_hybridnet(
    project_name=project_name,
    num_epochs=num_epochs,
    weights_keypoint_detect=weights_keypoint_detect,
    weights=weights_hybridnet,
    mode=mode,
    finetune=finetune
)

Training starts to run smoothly for a few batches:

Successfully loaded project 23-05-02-train-chronic-only.
[Info] Training HybridNet on project 23-05-02-train-chronic-only for 100 epochs!
[Info] Successfully loaded weights: /n/groups/datta/tim_sainburg/projects/JARVIS-HybridNet/projects/23-05-02-train-chronic-only/models/KeypointDetect/Run_20230504-002624/EfficientTrack-medium_final.pth
Epoch: 1/100. Loss: 298.4349. Acc: 68.30:   1%|▏         | 6/479 [00:02<03:30,  2.24it/s]

Then get the following error:

IndexError: Caught IndexError in DataLoader worker process 6.
Original Traceback (most recent call last):
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tis697/.conda/envs/jarvis_launcher/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/n/groups/datta/tim_sainburg/projects/JARVIS-HybridNet/jarvis/dataset/dataset3D.py", line 204, in __getitem__
    centerHM[frame_idx] = np.array([center_x, center_y])
IndexError: index 5 is out of bounds for axis 0 with size 5

I believe what is happening is that, since the Dataset3D object only has one version of self.reproTools, the network always expects the same reprojection parameters (despite different parameters existing for each training set, inside calib_params).
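The shape mismatch can be reproduced in isolation. If the dataset object keeps a single camera count taken from one calibration (here, hypothetically, 5), but a sample comes from a rig with more cameras (here 6), the write at frame index 5 falls out of bounds exactly as in the traceback. A minimal sketch, with the camera counts as assumptions:

```python
import numpy as np

# Hypothetical sizes: num_cameras was taken from a 5-camera calibration,
# but this sample comes from the 6-camera dataset.
num_cameras = 5          # what the dataset object ends up with
cameras_in_sample = 6    # what __getitem__ actually iterates over

centerHM = np.full((num_cameras, 2), 128, dtype=int)
err_msg = None
try:
    for frame_idx in range(cameras_in_sample):
        centerHM[frame_idx] = np.array([100, 200])  # (center_x, center_y)
except IndexError as e:
    err_msg = str(e)

print(err_msg)  # index 5 is out of bounds for axis 0 with size 5
```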

Any help with this would be appreciated!

Thanks


timsainb commented May 4, 2023

I think I localized the issue. The problem is that with multiple calibrations, you're resetting the number of cameras for each calibration you add, so it's only the last calibration that counts.

A training set is created here:
https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/train_interface.py#L15

training_set = Dataset3D(project.cfg, set='train',
                         cameras_to_use=camera_list)

Inside the training set, a calibration/ReprojectionTool is created for each dataset:
https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/dataset/dataset3D.py#L50
self.reproTools = {}
for calibParams in self.dataset['calibrations']:
    calibPaths = {}
    for cam in self.dataset['calibrations'][calibParams]:
        if self.cameras_to_use == None or cam in self.cameras_to_use:
            calibPaths[cam] = self.dataset['calibrations'] \
                        [calibParams][cam]
    self.reproTools[calibParams] = ReprojectionTool(
                os.path.join(cfg.PARENT_DIR, self.root_dir), calibPaths)
    self.num_cameras = self.reproTools[calibParams].num_cameras  # <-- overwritten on every iteration
    self.reproTools[calibParams].resolution = [width, height]

So later, when we try to grab something from the dataset, it expects the number of cameras from whichever calibration was loaded last:
https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/dataset/dataset3D.py#L189
centerHM = np.full((self.num_cameras, 2), 128, dtype = int)
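One possible direction for a fix (a sketch only, not tested against the actual jarvis code, with a toy stand-in for Dataset3D and the ReprojectionTool logic simplified away) is to keep a per-calibration camera count instead of a single self.num_cameras, and size centerHM from the calibration the sample actually belongs to:

```python
import numpy as np

class Dataset3DSketch:
    """Toy stand-in illustrating per-calibration camera counts."""

    def __init__(self, calibrations):
        # e.g. {'rig_A': 10, 'rig_B': 6, 'rig_C': 5} — cameras per calibration
        # (hypothetical names; the real code would read this from calib_params)
        self.num_cameras = dict(calibrations)   # dict instead of a single int

    def __getitem__(self, calibParams):
        # Size centerHM from the calibration this sample belongs to,
        # not from whichever calibration happened to be loaded last.
        n = self.num_cameras[calibParams]
        centerHM = np.full((n, 2), 128, dtype=int)
        return centerHM

ds = Dataset3DSketch({'rig_A': 10, 'rig_B': 6, 'rig_C': 5})
print(ds['rig_B'].shape)  # (6, 2)
print(ds['rig_C'].shape)  # (5, 2)
```

Note that this only fixes the per-sample allocation; batching samples from calibrations with different camera counts would presumably still need handling downstream (e.g. padding, or drawing each batch from a single calibration).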
