From 49ab4892a87753bd4895178d5cbdbb0551276ebe Mon Sep 17 00:00:00 2001 From: lgvaz Date: Mon, 2 Nov 2020 16:04:09 -0300 Subject: [PATCH] fixes pretrained in rcnn models --- CHANGELOG.md | 1 - icevision/models/rcnn/faster_rcnn/model.py | 8 ++------ icevision/models/rcnn/mask_rcnn/model.py | 6 +----- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cdaab599..5544e6054 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] ### Added -- Added `pretrained_backbone: bool = True` argument to both faster_rcnn and mask_rcnn `model()` methods. (#520) ### Changed ### Deleted diff --git a/icevision/models/rcnn/faster_rcnn/model.py b/icevision/models/rcnn/faster_rcnn/model.py index b6cd151cb..8aa492ca1 100644 --- a/icevision/models/rcnn/faster_rcnn/model.py +++ b/icevision/models/rcnn/faster_rcnn/model.py @@ -14,7 +14,6 @@ def model( backbone: Optional[nn.Module] = None, remove_internal_transforms: bool = True, pretrained: bool = True, - pretrained_backbone=True, **faster_rcnn_kwargs ) -> nn.Module: """FasterRCNN model implemented by torchvision. @@ -25,10 +24,9 @@ def model( remove_internal_transforms: The torchvision model internally applies transforms like resizing and normalization, but we already do this at the `Dataset` level, so it's safe to remove those internal transforms. - pretrained: Argument passed to `maskrcnn_resnet50_fpn` if `backbone is None`. + pretrained: Argument passed to `fastercnn_resnet50_fpn` if `backbone is None`. By default it is set to True: this is generally used when training a new model (transfer learning). `pretrained = False` is used during inference (prediction) for cases where the users have their own pretrained weights. - `pretrained_backbone = False` is used during inference (prediction) for cases where the users have their own pretrained backbone weights. **faster_rcnn_kwargs: Keyword arguments that internally are going to be passed to `torchvision.models.detection.faster_rcnn.FastRCNN`. @@ -37,9 +35,7 @@ def model( """ if backbone is None: model = fasterrcnn_resnet50_fpn( - pretrained=True, - pretrained_backbone=pretrained_backbone, - **faster_rcnn_kwargs + pretrained=pretrained, pretrained_backbone=pretrained, **faster_rcnn_kwargs ) in_features = model.roi_heads.box_predictor.cls_score.in_features diff --git a/icevision/models/rcnn/mask_rcnn/model.py b/icevision/models/rcnn/mask_rcnn/model.py index 12394616c..73c3f2c3a 100644 --- a/icevision/models/rcnn/mask_rcnn/model.py +++ b/icevision/models/rcnn/mask_rcnn/model.py @@ -15,7 +15,6 @@ def model( backbone: Optional[nn.Module] = None, remove_internal_transforms: bool = True, pretrained: bool = True, - pretrained_backbone=True, **mask_rcnn_kwargs ) -> nn.Module: """MaskRCNN model implemented by torchvision. @@ -29,7 +28,6 @@ def model( pretrained: Argument passed to `maskrcnn_resnet50_fpn` if `backbone is None`. By default it is set to True: this is generally used when training a new model (transfer learning). `pretrained = False` is used during inference (prediction) for cases where the users have their own pretrained weights. - `pretrained_backbone = False` is used during inference (prediction) for cases where the users have their own pretrained backbone weights. **mask_rcnn_kwargs: Keyword arguments that internally are going to be passed to `torchvision.models.detection.mask_rcnn.MaskRCNN`. @@ -38,9 +36,7 @@ def model( """ if backbone is None: model = maskrcnn_resnet50_fpn( - pretrained=pretrained, - pretrained_backbone=pretrained_backbone, - **mask_rcnn_kwargs + pretrained=pretrained, pretrained_backbone=pretrained, **mask_rcnn_kwargs ) in_features_box = model.roi_heads.box_predictor.cls_score.in_features