
Loading pre-trained model configuration from Python file #4737

Closed
buckeye17 opened this issue Jan 4, 2023 · 9 comments
Labels
documentation Problems about existing documentation or comments

Comments

@buckeye17

📚 Documentation Issue

I'm struggling to load the pre-trained model defined by new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py.
I've found relevant documentation here and here, as well as issue #3225, but none of these clearly explains my error.

I'm trying to load the configuration with:

cfg = LazyConfig.load("detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py")
cfg = setup_cfg(args)

This produces the following traceback:

Traceback (most recent call last):
  File "quality_test.py", line 97, in <module>
    results_ls = get_person_seg_masks(img_path, model_family, model)
  File "detectron2_wrapper.py", line 107, in get_person_seg_masks
    cfg = setup_cfg(args)
  File "detectron2/demo/demo.py", line 29, in setup_cfg
    cfg.merge_from_file(args.config_file)
  File "/home/appuser/detectron2_repo/detectron2/config/config.py", line 46, in merge_from_file
    loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
  File "/home/appuser/.local/lib/python3.8/site-packages/fvcore/common/config.py", line 61, in load_yaml_with_base
    cfg = yaml.safe_load(f)
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/__init__.py", line 125, in safe_load
    return load(stream, SafeLoader)
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/__init__.py", line 81, in load
    return loader.get_single_data()
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/constructor.py", line 49, in get_single_data
    node = self.get_single_node()
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/composer.py", line 39, in get_single_node
    if not self.check_event(StreamEndEvent):
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/parser.py", line 98, in check_event
    self.current_event = self.state()
  File "/home/appuser/.local/lib/python3.8/site-packages/yaml/parser.py", line 171, in parse_document_start
    raise ParserError(None, None,
yaml.parser.ParserError: expected '<document start>', but found '<scalar>'
  in "detectron2/configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py", line 11, column 1
@ppwwyyxx
Contributor

ppwwyyxx commented Jan 5, 2023

demo.py does not support Python configs.
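
In other words, YAML configs go through the CfgNode API that demo.py calls internally, while the .py configs must be loaded with LazyConfig. A minimal sketch of the two paths (the YAML path here is just an illustrative example):

from detectron2.config import LazyConfig, get_cfg

# YAML configs: the classic CfgNode API, which is what demo.py uses
cfg = get_cfg()
cfg.merge_from_file("configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

# Python configs: these are not YAML, so they must go through LazyConfig instead
lazy_cfg = LazyConfig.load("configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py")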

@aylinaydincs

How can we run the demo with Python configs such as MViT? And if we train MViTv2, does it do instance segmentation?

@buckeye17
Author

Since I struggled to get DensePose working, I wanted to share my function for working with it. Hopefully it helps other people get going quicker.

def get_person_mesh_mask(
    img_path: str,
    file_out_path: str = "",
    model_family: str = "iuv",
    model: str = "rcnn_R_101_FPN_DL_WC1M_s1x"
):
    '''
    Generates a segmentation mask for the most prominent person's body in the input image.

    Input Parameters:
    img_path: input image to be analyzed, string like "/path/to/image.jpg"
    file_out_path: image file to be produced, string like "/path/to/image.jpg"; must end with ".jpg", no images are made if ""
    model_family: which family of models to use, must be one of: "cse" or "iuv"
    model: which model to use within the given model family; see model_dict for valid values

    Returns a tuple of (bounding box, segmentation mask, confidence mask)
    NOTE: this function should be executed with detectron2_repo/projects/DensePose as the current working directory
    '''
    
    # this function utilizes detectron2/projects/DensePose/apply_net.py
    # documentation for it can be found at https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/projects/DensePose/doc/TOOL_APPLY_NET.md
    
    import subprocess
    import sys
    
    import numpy as np
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    import torch
    
    # the following model dictionary was built using the documentation at:
    # https://github.com/facebookresearch/detectron2/blob/main/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo
    # https://github.com/facebookresearch/detectron2/blob/main/projects/DensePose/doc/DENSEPOSE_CSE.md#ModelZoo
    # note that the dictionary does not include all available models mentioned in the documentation above
    # based on ~10 test images, some of the best models are: iuv - rcnn_R_50_FPN_s1x, iuv - rcnn_R_101_FPN_s1x & iuv - rcnn_R_101_FPN_DL_WC1M_s1x
    # for most of ~10 test images, resulting segmentation masks were very similar
    model_dict = {
        "cse": {
            "rcnn_R_50_FPN_s1x": {
                "yaml": "configs/densepose_rcnn_R_50_FPN_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl"
            },
            "R_101_FPN_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_s1x/251155500/model_final_5c995f.pkl"
            },
            # NOTE: the following model fails to run!?
            "rcnn_R_50_FPN_DL_s1x": {
                "yaml": "configs/cse/densepose_rcnn_R_50_FPN_DL_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_DL_s1x/251156349/model_final_e96218.pkl"
            },
            # NOTE: the following model fails to run!?
            "rcnn_R_101_FPN_DL_s1x": {
                "yaml": "configs/cse/densepose_rcnn_R_101_FPN_DL_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_s1x/251156606/model_final_b236ce.pkl"
            }
        },
        "iuv": {
            "rcnn_R_50_FPN_s1x": {
                "yaml": "configs/densepose_rcnn_R_50_FPN_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl"
            },
            "rcnn_R_101_FPN_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_s1x/165712084/model_final_c6ab63.pkl"
            },
            "rcnn_R_50_FPN_DL_s1x": {
                "yaml": "configs/densepose_rcnn_R_50_FPN_DL_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_DL_s1x/165712097/model_final_0ed407.pkl"
            },
            "rcnn_R_101_FPN_DL_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_DL_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_DL_s1x/165712116/model_final_844d15.pkl"
            },
            "rcnn_R_101_FPN_WC1M_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_WC1M_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_WC1M_s1x/216453687/model_final_0a7287.pkl"
            },
            "rcnn_R_101_FPN_WC2M_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_WC2M_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_WC2M_s1x/216245682/model_final_e354d9.pkl"
            },
            "rcnn_R_101_FPN_DL_WC1M_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_DL_WC1M_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_DL_WC1M_s1x/216245771/model_final_0ebeb3.pkl"
            },
            "rcnn_R_101_FPN_DL_WC2M_s1x": {
                "yaml": "configs/densepose_rcnn_R_101_FPN_DL_WC2M_s1x.yaml",
                "weights": "https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_101_FPN_DL_WC2M_s1x/216245790/model_final_de6e7a.pkl"
            },
        }
    }
    
    temp_pkl_path = "/home/appuser/results.pkl" # this assumes the function is being executed in a detectron2 Docker container
    make_images_bool = file_out_path.endswith(".jpg")
    if make_images_bool:
        cmd = f"python apply_net.py show {model_dict[model_family][model]['yaml']} {model_dict[model_family][model]['weights']} {img_path} dp_segm -v --output {file_out_path}"
        cmd = cmd.split(" ")
        
        # returns output as byte string
        returned_output = subprocess.check_output(cmd)

        # using decode() function to convert byte string to string
        output_str = returned_output.decode("utf-8")
        
        # make subprocess outputs visible to terminal
        print(output_str)
        
    cmd = f"python apply_net.py dump {model_dict[model_family][model]['yaml']} {model_dict[model_family][model]['weights']} {img_path} -v --output {temp_pkl_path}"
    cmd = cmd.split(" ")
    
    # returns output as byte string
    returned_output = subprocess.check_output(cmd)

    # using decode() function to convert byte string to string
    output_str = returned_output.decode("utf-8")
    
    # make subprocess outputs visible to terminal
    print(output_str)
    
    # make sure DensePose is in your PYTHONPATH, or use the following line to add it:
    sys.path.append("/home/appuser/detectron2_repo/projects/DensePose/")
    
    with open(temp_pkl_path, "rb") as hFile:
        data = torch.load(hFile)
    
    data_dict = data[0]
    
    # valid keys for data_dict: 'file_name', 'scores', 'pred_boxes_XYXY', 'pred_densepose'
    # assume the first predicted item is the intended subject, as observed with test images
    bbox = data_dict["pred_boxes_XYXY"][0].cpu().numpy()
    pred_densepose = data_dict["pred_densepose"][0]
    
    # valid pred_densepose properties: coarse_segm_confidence, fine_segm_confidence, kappa_u, kappa_v, labels, sigma_1, sigma_2, uv
    mask = pred_densepose.labels.cpu().numpy()

    # models without a confidence head have no usable fine_segm_confidence;
    # fall back to an all-zero confidence mask so the return value stays consistent
    if hasattr(pred_densepose.fine_segm_confidence, "cpu"):
        conf = pred_densepose.fine_segm_confidence.cpu().numpy()
    else:
        conf = np.zeros_like(mask, dtype=float)

    if make_images_bool:
        # plot the label mask beside the confidence mask
        fig = make_subplots(rows=1, cols=2, subplot_titles=["label", "confidence"])
        fig.add_trace(go.Heatmap(z=mask), row=1, col=1)
        fig.add_trace(go.Heatmap(z=conf), row=1, col=2)
        for y_axis, x_axis in [("yaxis1", "x1"), ("yaxis2", "x2")]:
            fig["layout"][y_axis]["autorange"] = "reversed"
            fig["layout"][y_axis]["scaleanchor"] = x_axis
            fig["layout"][y_axis]["scaleratio"] = 1

        file_out_base_path = file_out_path.replace(".jpg", "")
        fig.write_html(f"{file_out_base_path}_mask_plots.html")
    
    return bbox, mask, conf
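
A hypothetical call might look like the following, assuming detectron2_repo/projects/DensePose is the current working directory (the paths here are placeholders):

bbox, mask, conf = get_person_mesh_mask(
    img_path="/path/to/image.jpg",
    file_out_path="/path/to/output.jpg",  # also writes /path/to/output_mask_plots.html
    model_family="iuv",
    model="rcnn_R_101_FPN_DL_WC1M_s1x",
)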

@buckeye17
Author

I also wanted to mention that I haven't been able to figure out how to run inference with a pre-trained model defined by a Python LazyConfig. The closest guidance I could find is here, but that script only seems to cover evaluating common datasets; more work is required to apply the model to a custom dataset. There doesn't seem to be an easy way to run inference with LazyConfig models.

@ppwwyyxx
Contributor

cfg = LazyConfig.load(...)
model = instantiate(cfg.model)

Then follow https://detectron2.readthedocs.io/en/latest/tutorials/models.html#load-save-a-checkpoint to load a checkpoint.
Then follow https://detectron2.readthedocs.io/en/latest/tutorials/models.html#use-a-model to run inference.

@buckeye17
Author

@ppwwyyxx I tried following your recommendation with the following code:

cfg = LazyConfig.load("configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py")
model = instantiate(cfg.model)

from detectron2.checkpoint import DetectionCheckpointer
DetectionCheckpointer(model).load("detectron2://new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ/42073830/model_final_f96b26.pkl")  # load a file, usually from cfg.MODEL.WEIGHTS

# use PIL, to be consistent with evaluation
img = torch.from_numpy(np.ascontiguousarray(read_image(img_path, format="BGR")))
img = img.permute(2, 0, 1)  # HWC -> CHW
if torch.cuda.is_available():
    img = img.cuda()
inputs = [{"image": img}]
model.eval()
with torch.no_grad():
    predictions = model(inputs)

This generated the following error:

Traceback (most recent call last):
  File "/app/Anthropometry/quality_test.py", line 108, in <module>
    results_ls = get_person_seg_masks(img_path, model_family, model)
  File "/app/Anthropometry/detectron2_wrapper.py", line 152, in get_person_seg_masks
    predictions = model(inputs)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/detectron2_repo/detectron2/modeling/meta_arch/rcnn.py", line 150, in forward
    return self.inference(batched_inputs)
  File "/home/appuser/detectron2_repo/detectron2/modeling/meta_arch/rcnn.py", line 204, in inference
    features = self.backbone(images.tensor)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/detectron2_repo/detectron2/modeling/backbone/fpn.py", line 139, in forward
    bottom_up_features = self.bottom_up(x)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/detectron2_repo/detectron2/modeling/backbone/resnet.py", line 445, in forward
    x = self.stem(x)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/detectron2_repo/detectron2/modeling/backbone/resnet.py", line 356, in forward
    x = self.conv1(x)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/detectron2_repo/detectron2/layers/wrappers.py", line 117, in forward
    x = self.norm(x)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/appuser/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 683, in forward
    raise ValueError("SyncBatchNorm expected input tensor to be on GPU")
ValueError: SyncBatchNorm expected input tensor to be on GPU

Any idea what my problem is? Thanks again for your help!

@ppwwyyxx
Contributor

ppwwyyxx commented Jan 14, 2023 via email

@buckeye17
Author

Alright, well thanks to the help of @ppwwyyxx, I was able to run inference on a LazyConfig model with the following script!

import numpy as np
import torch

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.data.detection_utils import read_image

img_path = "/path/to/image.jpg"  # set this to your input image

cfg = LazyConfig.load("configs/new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ.py")

# edit the config to use regular BatchNorm: the new_baselines configs default to
# SyncBatchNorm, which raises an error when the input tensor is not on a GPU
cfg.model.backbone.bottom_up.stem.norm = "BN"
cfg.model.backbone.bottom_up.stages.norm = "BN"
cfg.model.backbone.norm = "BN"

model = instantiate(cfg.model)

DetectionCheckpointer(model).load("detectron2://new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ/42073830/model_final_f96b26.pkl")  # load a file, usually from cfg.MODEL.WEIGHTS

# read image for inference input
# use PIL, to be consistent with evaluation
img = torch.from_numpy(np.ascontiguousarray(read_image(img_path, format="BGR")))
img = img.permute(2, 0, 1)  # HWC -> CHW
if torch.cuda.is_available():
    img = img.cuda()
inputs = [{"image": img}]

# run the model
model.eval()
with torch.no_grad():
    predictions_ls = model(inputs)
predictions = predictions_ls[0]
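
For what it's worth, the output should follow detectron2's standard model output format, so the detections can be unpacked roughly like this (a sketch, assuming the usual Mask R-CNN output fields):

instances = predictions["instances"].to("cpu")  # detectron2 Instances object
boxes = instances.pred_boxes.tensor.numpy()     # (N, 4) boxes in XYXY format
scores = instances.scores.numpy()               # (N,) confidence scores
classes = instances.pred_classes.numpy()        # (N,) COCO class ids (0 == person)
masks = instances.pred_masks.numpy()            # (N, H, W) boolean masks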

@ShunLu91

(quotes the inference script from the previous comment)
This works for me! Thanks!
