From 2b6823f72665ce7c88cf8bc7aa784d9ecaacf4f5 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 7 Nov 2023 16:17:10 +0800 Subject: [PATCH 1/5] support lvis and chund eval --- demo/image_demo.py | 37 ++- mmdet/evaluation/functional/class_names.py | 247 ++++++++++++++++++++- mmdet/models/detectors/glip.py | 232 ++++++++++++++----- 3 files changed, 457 insertions(+), 59 deletions(-) diff --git a/demo/image_demo.py b/demo/image_demo.py index 2e2c27adbf2..5a9c906cef0 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -28,6 +28,16 @@ glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ --texts 'There are a lot of cars here.' + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ + --texts '$: coco' + + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ + --texts '$: lvis' --pred-score-thr 0.7 \ + --palette random --chunked-size 80 + + Visualize prediction results:: python demo/image_demo.py demo/demo.jpg rtmdet-ins-s --show @@ -41,6 +51,7 @@ from mmengine.logging import print_log from mmdet.apis import DetInferencer +from mmdet.evaluation import get_classes def parse_args(): @@ -60,7 +71,12 @@ def parse_args(): type=str, default='outputs', help='Output directory of images or prediction results.') - parser.add_argument('--texts', help='text prompt') + # Once you input a format similar to $: xxx, it indicates that + # the prompt is based on the dataset class name. + # support $: coco, $: voc, $: cityscapes, $: lvis, $: imagenet_det. + # detail to `mmdet/evaluation/functional/class_names.py` + parser.add_argument( + '--texts', help='text prompt, such as "bench . car .", "$: coco"') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( @@ -91,7 +107,7 @@ def parse_args(): default='none', choices=['coco', 'voc', 'citys', 'random', 'none'], help='Color palette used for visualization') - # only for GLIP + # only for GLIP and Grounding DINO parser.add_argument( '--custom-entities', '-c', @@ -99,6 +115,13 @@ def parse_args(): help='Whether to customize entity names? ' 'If so, the input text should be ' '"cls_name1 . cls_name2 . cls_name3 ." format') + parser.add_argument( + '--chunked-size', + '-s', + type=int, + default=-1, + help='If the number of categories is very large, ' + 'you can specify this parameter to truncate multiple predictions.') call_args = vars(parser.parse_args()) @@ -111,6 +134,12 @@ def parse_args(): call_args['weights'] = call_args['model'] call_args['model'] = None + if call_args['texts'] is not None: + if call_args['texts'].startswith('$:'): + dataset_name = call_args['texts'][3:].strip() + class_names = get_classes(dataset_name) + call_args['texts'] = [tuple(class_names)] + init_kws = ['model', 'weights', 'device', 'palette'] init_args = {} for init_kw in init_kws: @@ -125,6 +154,10 @@ def main(): # may consume too much memory if your input folder has a lot of images. # We will be optimized later. inferencer = DetInferencer(**init_args) + + chunked_size = call_args.pop('chunked_size') + inferencer.model.test_cfg.chunked_size = chunked_size + inferencer(**call_args) if call_args['out_dir'] != '' and not (call_args['no_save_vis'] diff --git a/mmdet/evaluation/functional/class_names.py b/mmdet/evaluation/functional/class_names.py index d0ea7094685..623a89cfdc0 100644 --- a/mmdet/evaluation/functional/class_names.py +++ b/mmdet/evaluation/functional/class_names.py @@ -485,6 +485,250 @@ def objects365v2_classes() -> list: ] +def lvis_classes() -> list: + """Class names of LVIS.""" + return [ + 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', + 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', + 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', + 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', + 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', + 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', + 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', + 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', + 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', + 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', + 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', + 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', + 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', + 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', + 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', + 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', + 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', + 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', + 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', + 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', + 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', + 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', + 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', + 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', + 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', + 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', + 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', + 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', + 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', + 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer', + 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', + 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', + 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', + 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', + 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', + 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', + 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', + 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', + 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', + 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', + 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', + 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', + 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', + 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', + 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', + 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', + 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', + 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', + 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', + 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', + 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', + 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', + 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', + 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', + 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', + 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', + 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', + 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', + 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', + 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', + 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet', + 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', + 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', + 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', + 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', + 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', + 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', + 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', + 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', + 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', + 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', + 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', + 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', + 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', + 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', + 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', + 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', + 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', + 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', + 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', + 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', + 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', + 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', + 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', + 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', + 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', + 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', + 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', + 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', + 'folding_chair', 'food_processor', 'football_(American)', + 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', + 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', + 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', + 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', + 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', + 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', + 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', + 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', + 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', + 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', + 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', + 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', + 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', + 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', + 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', + 'headboard', 'headlight', 'headscarf', 'headset', + 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', + 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', + 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', + 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', + 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', + 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', + 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', + 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', + 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', + 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', + 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', + 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', + 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', + 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', + 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', + 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce', + 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', + 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', + 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat', + 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', + 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', + 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', + 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', + 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', + 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', + 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', + 'mitten', 'mixer_(kitchen_tool)', 'money', + 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', + 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', + 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', + 'music_stool', 'musical_instrument', 'nailfile', 'napkin', + 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', + 'newsstand', 'nightshirt', 'nosebag_(for_animals)', + 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', + 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', + 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', + 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', + 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', + 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', + 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', + 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', + 'parchment', 'parka', 'parking_meter', 'parrot', + 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', + 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', + 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', + 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', + 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)', + 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', + 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', + 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', + 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', + 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', + 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', + 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', + 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', + 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', + 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', + 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', + 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', + 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', + 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', + 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', + 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', + 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', + 'recliner', 'record_player', 'reflector', 'remote_control', + 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', + 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', + 'rolling_pin', 'root_beer', 'router_(computer_equipment)', + 'rubber_band', 'runner_(carpet)', 'plastic_bag', + 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', + 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', + 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', + 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', + 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', + 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', + 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', + 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark', + 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', + 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', + 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', + 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', + 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', + 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', + 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', + 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', + 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', + 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', + 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', + 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', + 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', + 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', + 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', + 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', + 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', + 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', + 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', + 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', + 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', + 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', + 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', + 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', + 'tambourine', 'army_tank', 'tank_(storage_vessel)', + 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', + 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', + 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', + 'telephone_pole', 'telephoto_lens', 'television_camera', + 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', + 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread', + 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', + 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', + 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', + 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', + 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', + 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', + 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', + 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', + 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', + 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', + 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', + 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', + 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', + 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', + 'washbasin', 'automatic_washer', 'watch', 'water_bottle', + 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', + 'water_gun', 'water_scooter', 'water_ski', 'water_tower', + 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', + 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', + 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', + 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', + 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', + 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', + 'yoke_(animal_equipment)', 'zebra', 'zucchini' + ] + + dataset_aliases = { 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], @@ -496,7 +740,8 @@ def objects365v2_classes() -> list: 'oid_challenge': ['oid_challenge', 'openimages_challenge'], 'oid_v6': ['oid_v6', 'openimages_v6'], 'objects365v1': ['objects365v1', 'obj365v1'], - 'objects365v2': ['objects365v2', 'obj365v2'] + 'objects365v2': ['objects365v2', 'obj365v2'], + 'lvis': ['lvis', 'lvis_v1'], } diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py index 13cfea960a8..4011e73d09f 100644 --- a/mmdet/models/detectors/glip.py +++ b/mmdet/models/detectors/glip.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import re import warnings from typing import Optional, Tuple, Union @@ -166,6 +167,27 @@ def create_positive_map_label_to_token(positive_map: Tensor, return positive_map_label_to_token +def clean_label_name(name: str) -> str: + name = re.sub(r'\(.*\)', '', name) + name = re.sub(r'_', ' ', name) + name = re.sub(r' ', ' ', name) + return name + + +def chunks(lst: list, n: int) -> list: + """Yield successive n-sized chunks from lst.""" + all_ = [] + for i in range(0, len(lst), n): + data_index = lst[i:i + n] + all_.append(data_index) + counter = 0 + for i in all_: + counter += len(i) + assert (counter == len(lst)) + + return all_ + + @MODELS.register_module() class GLIP(SingleStageDetector): """Implementation of `GLIP `_ @@ -207,6 +229,46 @@ def __init__(self, self._special_tokens = '. ' + def to_enhance_text_prompts(self, original_caption, enhanced_text_prompts): + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + if word in enhanced_text_prompts: + enhanced_text_dict = enhanced_text_prompts[word] + if 'prefix' in enhanced_text_dict: + caption_string += enhanced_text_dict['prefix'] + start_i = len(caption_string) + if 'name' in enhanced_text_dict: + caption_string += enhanced_text_dict['name'] + else: + caption_string += word + end_i = len(caption_string) + tokens_positive.append([[start_i, end_i]]) + + if 'suffix' in enhanced_text_dict: + caption_string += enhanced_text_dict['suffix'] + else: + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + + if idx != len(original_caption) - 1: + caption_string += self._special_tokens + return caption_string, tokens_positive + + def to_plain_text_prompts(self, original_caption): + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + if idx != len(original_caption) - 1: + caption_string += self._special_tokens + return caption_string, tokens_positive + def get_tokens_and_prompts( self, original_caption: Union[str, list, tuple], @@ -221,44 +283,14 @@ def get_tokens_and_prompts( original_caption = list( filter(lambda x: len(x) > 0, original_caption)) + original_caption = [clean_label_name(i) for i in original_caption] + if custom_entities and enhanced_text_prompts is not None: - caption_string = '' - tokens_positive = [] - for idx, word in enumerate(original_caption): - if word in enhanced_text_prompts: - enhanced_text_dict = enhanced_text_prompts[word] - if 'prefix' in enhanced_text_dict: - caption_string += enhanced_text_dict['prefix'] - start_i = len(caption_string) - if 'name' in enhanced_text_dict: - caption_string += enhanced_text_dict['name'] - else: - caption_string += word - end_i = len(caption_string) - tokens_positive.append([[start_i, end_i]]) - - if 'suffix' in enhanced_text_dict: - caption_string += enhanced_text_dict['suffix'] - else: - tokens_positive.append([[ - len(caption_string), - len(caption_string) + len(word) - ]]) - caption_string += word - - if idx != len(original_caption) - 1: - caption_string += self._special_tokens + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption, enhanced_text_prompts) else: - caption_string = '' - tokens_positive = [] - for idx, word in enumerate(original_caption): - tokens_positive.append([[ - len(caption_string), - len(caption_string) + len(word) - ]]) - caption_string += word - if idx != len(original_caption) - 1: - caption_string += self._special_tokens + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption) tokenized = self.language_model.tokenizer([caption_string], return_tensors='pt') @@ -285,14 +317,73 @@ def get_tokens_positive_and_prompts( custom_entities: bool = False, enhanced_text_prompt: Optional[ConfigType] = None ) -> Tuple[dict, str, Tensor, list]: - tokenized, caption_string, tokens_positive, entities = \ - self.get_tokens_and_prompts( - original_caption, custom_entities, enhanced_text_prompt) - positive_map_label_to_token, positive_map = self.get_positive_map( - tokenized, tokens_positive) + chunked_size = self.test_cfg.get('chunked_size', -1) + if not self.training and chunked_size > 0: + assert isinstance(original_caption, + (list, tuple)) or custom_entities is True + all_output = self.get_tokens_positive_and_prompts_chunked( + original_caption, enhanced_text_prompt) + positive_map_label_to_token, \ + caption_string, \ + positive_map, \ + entities = all_output + else: + tokenized, caption_string, tokens_positive, entities = \ + self.get_tokens_and_prompts( + original_caption, custom_entities, enhanced_text_prompt) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + if tokenized.input_ids.shape[1] > self.language_model.max_tokens: + warnings.warn('Inputting a text that is too long will result ' + 'in poor prediction performance. ' + 'Please reduce the text length.') return positive_map_label_to_token, caption_string, \ positive_map, entities + def get_tokens_positive_and_prompts_chunked( + self, + original_caption: Union[list, tuple], + enhanced_text_prompts: Optional[ConfigType] = None): + chunked_size = self.test_cfg.get('chunked_size', -1) + original_caption = [clean_label_name(i) for i in original_caption] + + original_caption_chunked = chunks(original_caption, chunked_size) + ids_chunked = chunks( + list(range(1, + len(original_caption) + 1)), chunked_size) + + positive_map_label_to_token_chunked = [] + caption_string_chunked = [] + positive_map_chunked = [] + entities_chunked = [] + + for i in range(len(ids_chunked)): + if enhanced_text_prompts is not None: + caption_string, tokens_positive = self.to_enhance_text_prompts( + original_caption_chunked[i], enhanced_text_prompts) + else: + caption_string, tokens_positive = self.to_plain_text_prompts( + original_caption_chunked[i]) + tokenized = self.language_model.tokenizer([caption_string], + return_tensors='pt') + if tokenized.input_ids.shape[1] > self.language_model.max_tokens: + warnings.warn('Inputting a text that is too long will result ' + 'in poor prediction performance. ' + 'Please reduce the --chunked-size.') + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + + caption_string_chunked.append(caption_string) + positive_map_label_to_token_chunked.append( + positive_map_label_to_token) + positive_map_chunked.append(positive_map) + entities_chunked.append(original_caption_chunked[i]) + + return positive_map_label_to_token_chunked, \ + caption_string_chunked, \ + positive_map_chunked, \ + entities_chunked + def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> Union[dict, list]: # TODO: Only open vocabulary tasks are supported for training now. @@ -376,12 +467,14 @@ def predict(self, - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). """ - text_prompts = [ - data_samples.text for data_samples in batch_data_samples - ] - enhanced_text_prompts = [ - data_samples.caption_prompt for data_samples in batch_data_samples - ] + text_prompts = [] + enhanced_text_prompts = [] + for data_samples in batch_data_samples: + text_prompts.append(data_samples.text) + if 'caption_prompt' in data_samples: + enhanced_text_prompts.append(data_samples.caption_prompt) + else: + enhanced_text_prompts.append(None) if 'custom_entities' in batch_data_samples[0]: # Assuming that the `custom_entities` flag @@ -409,18 +502,45 @@ def predict(self, token_positive_maps, text_prompts, _, entities = zip( *_positive_maps_and_prompts) - language_dict_features = self.language_model(list(text_prompts)) + visual_features = self.extract_feat(batch_inputs) - for i, data_samples in enumerate(batch_data_samples): - data_samples.token_positive_map = token_positive_maps[i] + if isinstance(text_prompts[0], list): + # chunked text prompts, only bs=1 is supported + assert len(batch_inputs) == 1 + count = 0 + results_list = [] + + entities = [[item for lst in entities[0] for item in lst]] + + for b in range(len(text_prompts[0])): + text_prompts_once = [text_prompts[0][b]] + token_positive_maps_once = token_positive_maps[0][b] + language_dict_features = self.language_model(text_prompts_once) + batch_data_samples[ + 0].token_positive_map = token_positive_maps_once + + pred_instances = self.bbox_head.predict( + copy.deepcopy(visual_features), + language_dict_features, + batch_data_samples, + rescale=rescale)[0] + + if len(pred_instances) > 0: + pred_instances.labels += count + count += len(token_positive_maps_once) + results_list.append(pred_instances) + results_list = [results_list[0].cat(results_list)] + else: + language_dict_features = self.language_model(list(text_prompts)) - visual_features = self.extract_feat(batch_inputs) + for i, data_samples in enumerate(batch_data_samples): + data_samples.token_positive_map = token_positive_maps[i] - results_list = self.bbox_head.predict( - visual_features, - language_dict_features, - batch_data_samples, - rescale=rescale) + results_list = self.bbox_head.predict( + visual_features, + language_dict_features, + batch_data_samples, + rescale=rescale) for data_sample, pred_instances, entity in zip(batch_data_samples, results_list, entities): From f58ef44ea8de862820d02bbca49af1f5002ece54 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 7 Nov 2023 19:43:59 +0800 Subject: [PATCH 2/5] support lvis fix map --- ...in-t_a_fpn_dyhead_pretrain_zershot_lvis.py | 25 +++ mmdet/evaluation/metrics/lvis_metric.py | 182 +++++++++++++++++- 2 files changed, 202 insertions(+), 5 deletions(-) create mode 100644 configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py diff --git a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py new file mode 100644 index 00000000000..153ecdb0992 --- /dev/null +++ b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py @@ -0,0 +1,25 @@ +_base_ = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' + +model = dict( + test_cfg=dict( + max_per_img=300, + chunked_size=40, + ) +) + +dataset_type = 'LVISV1Dataset' +data_root = 'data/coco/' + +val_dataloader = dict( + dataset=dict( + data_root=data_root, + type=dataset_type, + ann_file='annotations/lvis_v1_minival_inserted_image_name.json', + data_prefix=dict(img=''))) +test_dataloader = val_dataloader + +val_evaluator = dict( + _delete_=True, + type='LVISFixedAPMetric', + ann_file=data_root + 'annotations/lvis_v1_minival_inserted_image_name.json') +test_evaluator = val_evaluator diff --git a/mmdet/evaluation/metrics/lvis_metric.py b/mmdet/evaluation/metrics/lvis_metric.py index e4dd6141c0e..819537400f7 100644 --- a/mmdet/evaluation/metrics/lvis_metric.py +++ b/mmdet/evaluation/metrics/lvis_metric.py @@ -4,23 +4,30 @@ import tempfile import warnings from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union - +from typing import Dict, List, Optional, Sequence, Union, Any +from mmengine.logging import print_log +import torch import numpy as np from mmengine.fileio import get_local_path from mmengine.logging import MMLogger from terminaltables import AsciiTable - +from mmengine.evaluator.metric import _to_cpu from mmdet.registry import METRICS from mmdet.structures.mask import encode_mask_results from ..functional import eval_recalls from .coco_metric import CocoMetric +from mmengine.evaluator import BaseMetric +from collections import defaultdict +import logging +from mmengine.dist import all_gather_object, is_main_process, broadcast_object_list try: import lvis + if getattr(lvis, '__version__', '0') >= '10.5.3': warnings.warn( - 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501 + 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', + # noqa: E501 UserWarning) from lvis import LVIS, LVISEval, LVISResults except ImportError: @@ -122,7 +129,8 @@ def __init__(self, raise RuntimeError( 'The `file_client_args` is deprecated, ' 'please use `backend_args` instead, please refer to' - 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' + # noqa: E501 ) # if ann_file is not specified, @@ -362,3 +370,167 @@ def compute_metrics(self, results: list) -> Dict[str, float]: if tmp_dir is not None: tmp_dir.cleanup() return eval_results + + +def _merge_lists(listA, listB, maxN, key): + result = [] + indA, indB = 0, 0 + while (indA < len(listA) or indB < len(listB)) and len(result) < maxN: + if (indB < len(listB)) and (indA >= len(listA) or key(listA[indA]) < key(listB[indB])): + result.append(listB[indB]) + indB += 1 + else: + result.append(listA[indA]) + indA += 1 + return result + + +@METRICS.register_module() +class LVISFixedAPMetric(BaseMetric): + default_prefix: Optional[str] = 'lvis_fixed_ap' + + def __init__(self, + ann_file: str, + topk: int = 10000, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + backend_args: dict = None) -> None: + + if lvis is None: + raise RuntimeError( + 'Package lvis is not installed. Please run "pip install ' + 'git+https://github.com/lvis-dataset/lvis-api.git".') + super().__init__(collect_device=collect_device, prefix=prefix) + + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, 'outfile_prefix must be not' + 'None when format_only is True, otherwise the result files will' + 'be saved to a temp directory which will be cleaned up at the end.' + + self.outfile_prefix = outfile_prefix + self.backend_args = backend_args + + with get_local_path( + ann_file, backend_args=self.backend_args) as local_path: + self._lvis_api = LVIS(local_path) + + # handle dataset lazy init + self.cat_ids = None + self.img_ids = None + + self.results = {} + self.topk = topk + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + cur_results = [] + for data_sample in data_samples: + pred = data_sample['pred_instances'] + xmin, ymin, xmax, ymax = pred['bboxes'].cpu().unbind(1) + boxes = torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1).tolist() + + scores = pred['scores'].cpu().numpy() + labels = pred['labels'].cpu().numpy() + + if len(boxes) == 0: + continue + + cur_results.extend( + [ + { + "image_id": data_sample['img_id'], + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + + by_cat = defaultdict(list) + for ann in cur_results: + by_cat[ann["category_id"]].append(ann) + + for cat, cat_anns in by_cat.items(): + if cat not in self.results: + self.results[cat] = [] + + cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk] + self.results[cat] = _merge_lists(self.results[cat], cur, self.topk, key=lambda x: x["score"]) + + def compute_metrics(self, results: dict) -> dict: + logger: MMLogger = MMLogger.get_current_instance() + + new_results = [] + + missing_dets_cats = set() + for cat, cat_anns in results.items(): + if len(cat_anns) < self.topk: + missing_dets_cats.add(cat) + new_results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]) + + if missing_dets_cats: + logger.info( + f"\n===\n" + f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n" + f"Outputting {self.topk} detections for each class will improve AP further.\n" + f"If using detectron2, please use the lvdevil/infer_topk.py script to " + f"output a results file with {self.topk} detections for each class.\n" + f"===" + ) + + new_results = LVISResults(self._lvis_api, new_results, max_dets=-1) + lvis_eval = LVISEval(self._lvis_api, new_results, iou_type="bbox") + params = lvis_eval.params + params.max_dets = -1 # No limit on detections per image. + lvis_eval.run() + lvis_eval.print_results() + metrics = {k: v for k, v in lvis_eval.results.items() if k.startswith("AP")} + logger.info(f'mAP_copypaste: {metrics}') + return metrics + + def evaluate(self, size: int) -> dict: + if len(self.results) == 0: + print_log( + f'{self.__class__.__name__} got empty `self.results`. Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.', + logger='current', + level=logging.WARNING) + + all_cats = all_gather_object(self.results) + results = defaultdict(list) + for cats in all_cats: + for cat, cat_anns in cats.items(): + results[cat].extend(cat_anns) + + if is_main_process(): + # cast all tensors in results list to cpu + results = _to_cpu(results) + _metrics = self.compute_metrics(results) # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results + self.results = {} + return metrics[0] From fa548c8f8dcbcc6787f6faed1248685784b4c9c0 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Tue, 7 Nov 2023 20:24:19 +0800 Subject: [PATCH 3/5] support lvis fix map --- ...in-t_a_fpn_dyhead_pretrain_zershot_lvis.py | 13 ++- mmdet/evaluation/metrics/lvis_metric.py | 84 +++++++++---------- 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py index 153ecdb0992..9127569cc35 100644 --- a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py +++ b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py @@ -1,11 +1,9 @@ _base_ = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' -model = dict( - test_cfg=dict( - max_per_img=300, - chunked_size=40, - ) -) +model = dict(test_cfg=dict( + max_per_img=300, + chunked_size=40, +)) dataset_type = 'LVISV1Dataset' data_root = 'data/coco/' @@ -21,5 +19,6 @@ val_evaluator = dict( _delete_=True, type='LVISFixedAPMetric', - ann_file=data_root + 'annotations/lvis_v1_minival_inserted_image_name.json') + ann_file=data_root + + 'annotations/lvis_v1_minival_inserted_image_name.json') test_evaluator = val_evaluator diff --git a/mmdet/evaluation/metrics/lvis_metric.py b/mmdet/evaluation/metrics/lvis_metric.py index 819537400f7..a861c6ee7b4 100644 --- a/mmdet/evaluation/metrics/lvis_metric.py +++ b/mmdet/evaluation/metrics/lvis_metric.py @@ -1,33 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools +import logging import os.path as osp import tempfile import warnings -from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union, Any -from mmengine.logging import print_log -import torch +from collections import OrderedDict, defaultdict +from typing import Dict, List, Optional, Sequence, Union + import numpy as np +import torch +from mmengine.dist import (all_gather_object, broadcast_object_list, + is_main_process) +from mmengine.evaluator import BaseMetric +from mmengine.evaluator.metric import _to_cpu from mmengine.fileio import get_local_path -from mmengine.logging import MMLogger +from mmengine.logging import MMLogger, print_log from terminaltables import AsciiTable -from mmengine.evaluator.metric import _to_cpu + from mmdet.registry import METRICS from mmdet.structures.mask import encode_mask_results from ..functional import eval_recalls from .coco_metric import CocoMetric -from mmengine.evaluator import BaseMetric -from collections import defaultdict -import logging -from mmengine.dist import all_gather_object, is_main_process, broadcast_object_list try: import lvis if getattr(lvis, '__version__', '0') >= '10.5.3': warnings.warn( - 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', - # noqa: E501 + 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501 UserWarning) from lvis import LVIS, LVISEval, LVISResults except ImportError: @@ -129,8 +129,7 @@ def __init__(self, raise RuntimeError( 'The `file_client_args` is deprecated, ' 'please use `backend_args` instead, please refer to' - 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' - # noqa: E501 + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 ) # if ann_file is not specified, @@ -376,7 +375,8 @@ def _merge_lists(listA, listB, maxN, key): result = [] indA, indB = 0, 0 while (indA < len(listA) or indB < len(listB)) and len(result) < maxN: - if (indB < len(listB)) and (indA >= len(listA) or key(listA[indA]) < key(listB[indB])): + if (indB < len(listB)) and (indA >= len(listA) + or key(listA[indA]) < key(listB[indB])): result.append(listB[indB]) indB += 1 else: @@ -417,9 +417,7 @@ def __init__(self, ann_file, backend_args=self.backend_args) as local_path: self._lvis_api = LVIS(local_path) - # handle dataset lazy init - self.cat_ids = None - self.img_ids = None + self.cat_ids = self._lvis_api.get_cat_ids() self.results = {} self.topk = topk @@ -438,7 +436,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: for data_sample in data_samples: pred = data_sample['pred_instances'] xmin, ymin, xmax, ymax = pred['bboxes'].cpu().unbind(1) - boxes = torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1).tolist() + boxes = torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), + dim=1).tolist() scores = pred['scores'].cpu().numpy() labels = pred['labels'].cpu().numpy() @@ -446,28 +445,25 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: if len(boxes) == 0: continue - cur_results.extend( - [ - { - "image_id": data_sample['img_id'], - "category_id": labels[k], - "bbox": box, - "score": scores[k], - } - for k, box in enumerate(boxes) - ] - ) + cur_results.extend([{ + 'image_id': data_sample['img_id'], + 'category_id': self.cat_ids[labels[k]], + 'bbox': box, + 'score': scores[k], + } for k, box in enumerate(boxes)]) by_cat = defaultdict(list) for ann in cur_results: - by_cat[ann["category_id"]].append(ann) + by_cat[ann['category_id']].append(ann) for cat, cat_anns in by_cat.items(): if cat not in self.results: self.results[cat] = [] - cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk] - self.results[cat] = _merge_lists(self.results[cat], cur, self.topk, key=lambda x: x["score"]) + cur = sorted( + cat_anns, key=lambda x: x['score'], reverse=True)[:self.topk] + self.results[cat] = _merge_lists( + self.results[cat], cur, self.topk, key=lambda x: x['score']) def compute_metrics(self, results: dict) -> dict: logger: MMLogger = MMLogger.get_current_instance() @@ -478,25 +474,27 @@ def compute_metrics(self, results: dict) -> dict: for cat, cat_anns in results.items(): if len(cat_anns) < self.topk: missing_dets_cats.add(cat) - new_results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]) + new_results.extend( + sorted(cat_anns, key=lambda x: x['score'], + reverse=True)[:self.topk]) if missing_dets_cats: logger.info( - f"\n===\n" - f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n" - f"Outputting {self.topk} detections for each class will improve AP further.\n" - f"If using detectron2, please use the lvdevil/infer_topk.py script to " - f"output a results file with {self.topk} detections for each class.\n" - f"===" - ) + f'\n===\n' + f'{len(missing_dets_cats)} classes had less than {self.topk} ' + f'detections!\n Outputting {self.topk} detections for each ' + f'class will improve AP further.\n ===') new_results = LVISResults(self._lvis_api, new_results, max_dets=-1) - lvis_eval = LVISEval(self._lvis_api, new_results, iou_type="bbox") + lvis_eval = LVISEval(self._lvis_api, new_results, iou_type='bbox') params = lvis_eval.params params.max_dets = -1 # No limit on detections per image. lvis_eval.run() lvis_eval.print_results() - metrics = {k: v for k, v in lvis_eval.results.items() if k.startswith("AP")} + metrics = { + k: v + for k, v in lvis_eval.results.items() if k.startswith('AP') + } logger.info(f'mAP_copypaste: {metrics}') return metrics From 1a7cc5c222c82e5142719bcedce1ef2c4c8dbb92 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Wed, 8 Nov 2023 12:57:51 +0800 Subject: [PATCH 4/5] add more configs --- ...win-l_fpn_dyhead_pretrain_zeroshot_lvis.py | 12 ++++++++++ ..._fpn_dyhead_pretrain_zeroshot_mini-lvis.py | 12 ++++++++++ ...n-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py | 24 +++++++++++++++++++ ...fpn_dyhead_pretrain_zeroshot_mini-lvis.py} | 3 ++- ...-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py | 3 +++ ..._fpn_dyhead_pretrain_zeroshot_mini-lvis.py | 3 +++ 6 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py rename configs/glip/{glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py => lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py} (87%) create mode 100644 configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py create mode 100644 configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py diff --git a/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..1f79e447d3f --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,12 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py' + +model = dict( + backbone=dict( + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + drop_path_rate=0.4, + ), + neck=dict(in_channels=[384, 768, 1536]), + bbox_head=dict(early_fuse=True, num_dyhead_blocks=8)) diff --git a/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py new file mode 100644 index 00000000000..13f1a69082b --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -0,0 +1,12 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py' + +model = dict( + backbone=dict( + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=12, + drop_path_rate=0.4, + ), + neck=dict(in_channels=[384, 768, 1536]), + bbox_head=dict(early_fuse=True, num_dyhead_blocks=8)) diff --git a/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..4d526d59008 --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,24 @@ +_base_ = '../glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' + +model = dict(test_cfg=dict( + max_per_img=300, + chunked_size=40, +)) + +dataset_type = 'LVISV1Dataset' +data_root = 'data/coco/' + +val_dataloader = dict( + dataset=dict( + data_root=data_root, + type=dataset_type, + ann_file='annotations/lvis_od_val.json', + data_prefix=dict(img=''))) +test_dataloader = val_dataloader + +# numpy < 1.24.0 +val_evaluator = dict( + _delete_=True, + type='LVISFixedAPMetric', + ann_file=data_root + 'annotations/lvis_od_val.json') +test_evaluator = val_evaluator diff --git a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py similarity index 87% rename from configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py rename to configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py index 9127569cc35..70a57a3f581 100644 --- a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_zershot_lvis.py +++ b/configs/glip/lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -1,4 +1,4 @@ -_base_ = 'glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' +_base_ = '../glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py' model = dict(test_cfg=dict( max_per_img=300, @@ -16,6 +16,7 @@ data_prefix=dict(img=''))) test_dataloader = val_dataloader +# numpy < 1.24.0 val_evaluator = dict( _delete_=True, type='LVISFixedAPMetric', diff --git a/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py new file mode 100644 index 00000000000..6dc712b3bcb --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py' + +model = dict(bbox_head=dict(early_fuse=True)) diff --git a/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py new file mode 100644 index 00000000000..3babb91101a --- /dev/null +++ b/configs/glip/lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_mini-lvis.py @@ -0,0 +1,3 @@ +_base_ = './glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_mini-lvis.py' + +model = dict(bbox_head=dict(early_fuse=True)) From 8e80c13ee0eb3960f0d72b1dcd05a1b047ad24d4 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Thu, 9 Nov 2023 14:21:38 +0800 Subject: [PATCH 5/5] add result --- configs/glip/README.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/configs/glip/README.md b/configs/glip/README.md index 1252d922ac8..c45cb6dbb6e 100644 --- a/configs/glip/README.md +++ b/configs/glip/README.md @@ -56,7 +56,7 @@ model.save_pretrained("your path/bert-base-uncased") tokenizer.save_pretrained("your path/bert-base-uncased") ``` -## Results and Models +## COCO Results and Models | Model | Zero-shot or Finetune | COCO mAP | Official COCO mAP | Pre-Train Data | Config | Download | | :--------: | :-------------------: | :------: | ----------------: | :------------------------: | :---------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | @@ -78,3 +78,24 @@ Note: 3. Taking the GLIP-T(A) model as an example, I trained it twice using the official code, and the fine-tuning mAP were 52.5 and 52.6. Therefore, the mAP we achieved in our reproduction is higher than the official results. The main reason is that we modified the `weight_decay` parameter. 4. Our experiments revealed that training for 24 epochs leads to overfitting. Therefore, we chose the best-performing model. If users want to train on a custom dataset, it is advisable to shorten the number of epochs and save the best-performing model. 5. Due to the official absence of fine-tuning hyperparameters for the GLIP-L model, we have not yet reproduced the official accuracy. I have found that overfitting can also occur, so it may be necessary to consider custom modifications to data augmentation and model enhancement. Given the high cost of training, we have not conducted any research on this matter at the moment. + +## LVIS Results + +| Model | Official | MiniVal APr | MiniVal APc | MiniVal APf | MiniVal AP | Val1.0 APr | Val1.0 APc | Val1.0 APf | Val1.0 AP | Pre-Train Data | Config | Download | +| :--------: | :------: | :---------: | :---------: | :---------: | :--------: | :--------: | :--------: | :--------: | :-------: | :------------------------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------------------------: | +| GLIP-T (A) | ✔ | | | | | | | | | O365 | [config](lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | +| GLIP-T (A) | | 12.1 | 15.5 | 25.8 | 20.2 | 6.2 | 10.9 | 22.8 | 14.7 | O365 | [config](lvis/glip_atss_swin-t_a_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth) | +| GLIP-T (B) | ✔ | | | | | | | | | O365 | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | +| GLIP-T (B) | | 8.6 | 13.9 | 26.0 | 19.3 | 4.6 | 9.8 | 22.6 | 13.9 | O365 | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_b_mmdet-6dfbd102.pth) | +| GLIP-T (C) | ✔ | 14.3 | 19.4 | 31.1 | 24.6 | | | | | O365,GoldG | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | +| GLIP-T (C) | | 14.4 | 19.8 | 31.9 | 25.2 | 8.3 | 13.2 | 28.1 | 18.2 | O365,GoldG | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_c_mmdet-2fc427dd.pth) | +| GLIP-T | ✔ | | | | | | | | | O365,GoldG,CC3M,SBU | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | +| GLIP-T | | 18.1 | 21.2 | 33.1 | 26.7 | 10.8 | 14.7 | 29.0 | 19.6 | O365,GoldG,CC3M,SBU | [config](lvis/glip_atss_swin-t_bc_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_mmdet-c24ce662.pth) | +| GLIP-L | ✔ | 29.2 | 34.9 | 42.1 | 37.9 | | | | | FourODs,GoldG,CC3M+12M,SBU | [config](lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | +| GLIP-L | | 27.9 | 33.7 | 39.7 | 36.1 | | | | | FourODs,GoldG,CC3M+12M,SBU | [config](lvis/glip_atss_swin-l_fpn_dyhead_pretrain_zeroshot_lvis.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/glip/glip_l_mmdet-abfe026b.pth) | + +Note: + +1. The above are zero-shot evaluation results. +2. The evaluation metric we used is LVIS FixAP. For specific details, please refer to [Evaluating Large-Vocabulary Object Detectors: The Devil is in the Details](https://arxiv.org/pdf/2102.01066.pdf). +3. We found that the performance on small models is better than the official results, but it is lower on large models. This is mainly due to the incomplete alignment of the GLIP post-processing.