Skip to content

Commit

Permalink
fixes pretrained in rcnn models
Browse files Browse the repository at this point in the history
  • Loading branch information
lgvaz committed Nov 2, 2020
1 parent 5aa91d5 commit 49ab489
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]

### Added
- Added `pretrained_backbone: bool = True` argument to both faster_rcnn and mask_rcnn `model()` methods. (#520)
### Changed
### Deleted

Expand Down
8 changes: 2 additions & 6 deletions icevision/models/rcnn/faster_rcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def model(
backbone: Optional[nn.Module] = None,
remove_internal_transforms: bool = True,
pretrained: bool = True,
pretrained_backbone=True,
**faster_rcnn_kwargs
) -> nn.Module:
"""FasterRCNN model implemented by torchvision.
Expand All @@ -25,10 +24,9 @@ def model(
remove_internal_transforms: The torchvision model internally applies transforms
like resizing and normalization, but we already do this at the `Dataset` level,
so it's safe to remove those internal transforms.
pretrained: Argument passed to `maskrcnn_resnet50_fpn` if `backbone is None`.
pretrained: Argument passed to `fastercnn_resnet50_fpn` if `backbone is None`.
By default it is set to True: this is generally used when training a new model (transfer learning).
`pretrained = False` is used during inference (prediction) for cases where the users have their own pretrained weights.
`pretrained_backbone = False` is used during inference (prediction) for cases where the users have their own pretrained backbone weights.
**faster_rcnn_kwargs: Keyword arguments that internally are going to be passed to
`torchvision.models.detection.faster_rcnn.FastRCNN`.
Expand All @@ -37,9 +35,7 @@ def model(
"""
if backbone is None:
model = fasterrcnn_resnet50_fpn(
pretrained=True,
pretrained_backbone=pretrained_backbone,
**faster_rcnn_kwargs
pretrained=pretrained, pretrained_backbone=pretrained, **faster_rcnn_kwargs
)

in_features = model.roi_heads.box_predictor.cls_score.in_features
Expand Down
6 changes: 1 addition & 5 deletions icevision/models/rcnn/mask_rcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def model(
backbone: Optional[nn.Module] = None,
remove_internal_transforms: bool = True,
pretrained: bool = True,
pretrained_backbone=True,
**mask_rcnn_kwargs
) -> nn.Module:
"""MaskRCNN model implemented by torchvision.
Expand All @@ -29,7 +28,6 @@ def model(
pretrained: Argument passed to `maskrcnn_resnet50_fpn` if `backbone is None`.
By default it is set to True: this is generally used when training a new model (transfer learning).
`pretrained = False` is used during inference (prediction) for cases where the users have their own pretrained weights.
`pretrained_backbone = False` is used during inference (prediction) for cases where the users have their own pretrained backbone weights.
**mask_rcnn_kwargs: Keyword arguments that internally are going to be passed to
`torchvision.models.detection.mask_rcnn.MaskRCNN`.
Expand All @@ -38,9 +36,7 @@ def model(
"""
if backbone is None:
model = maskrcnn_resnet50_fpn(
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
**mask_rcnn_kwargs
pretrained=pretrained, pretrained_backbone=pretrained, **mask_rcnn_kwargs
)

in_features_box = model.roi_heads.box_predictor.cls_score.in_features
Expand Down

0 comments on commit 49ab489

Please sign in to comment.