Skip to content

Commit

Permalink
adapter fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sileod committed Apr 6, 2023
1 parent 62c890f commit 78f945c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def __init__(self, config, classifiers=None, Z=None, labels_list=[]):
self.classifiers=torch.nn.ModuleList(
[torch.nn.Linear(config.hidden_size,size) for size in config.classifiers_size]
) if classifiers==None else classifiers

self.config=self.config.from_dict(
{**self.config.to_dict(),
'labels_list':labels_list}
)
def adapt_model_to_task(self, model, task_name):
task_index=self.config.tasks.index(task_name)
last_linear(model).weight = last_linear(self.classifiers[task_index]).weight
setattr(model,search_module(model,'linear',mode='class')[-1], self.classifiers[task_index])
return model
def _init_weights(*args):
pass
Expand Down

0 comments on commit 78f945c

Please sign in to comment.