Skip to content

Commit

Permalink
[BUGFIX] Match anchor and stride dims
Browse files Browse the repository at this point in the history
  • Loading branch information
Sefa Burak Okcu committed Dec 12, 2023
1 parent a7576b6 commit 9a35bb6
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

model_definition['anchors'] = (model.module.model[-1].anchors*model.module.model[-1].stride).cpu().reshape(model.module.model[-1].anchors.shape[0], -1).tolist() if hasattr(model, 'module') else (model.model[-1].anchors*model.model[-1].stride).cpu().reshape(model.model[-1].anchors.shape[0], -1).tolist()
model_definition['anchors'] = (model.module.model[-1].anchors*model.module.model[-1].stride).cpu().reshape(model.module.model[-1].anchors.shape[0], -1).tolist()\
if hasattr(model, 'module') else (model.model[-1].anchors*model.model[-1].stride[:, None, None]).cpu().reshape(model.model[-1].anchors.shape[0], -1).tolist()
with open(save_dir / 'model.yaml', 'w') as f:
yaml.safe_dump(model_definition, f, sort_keys=False)

Expand Down Expand Up @@ -642,7 +643,7 @@ def run(**kwargs):
if __name__ == "__main__":
opt = parse_opt()

opt.data = 'data/widerface.yaml'
#opt.data = 'data/widerface.yaml'
opt.noval = True
opt.noautoanchor = True

Expand Down

0 comments on commit 9a35bb6

Please sign in to comment.