diff --git a/src/tasknet/models.py b/src/tasknet/models.py index aa44cf5..80d0085 100755 --- a/src/tasknet/models.py +++ b/src/tasknet/models.py @@ -44,10 +44,13 @@ def __init__(self, config, classifiers=None, Z=None, labels_list=[]): self.classifiers=torch.nn.ModuleList( [torch.nn.Linear(config.hidden_size,size) for size in config.classifiers_size] ) if classifiers==None else classifiers - + self.config=self.config.from_dict( + {**self.config.to_dict(), + 'labels_list':labels_list} + ) def adapt_model_to_task(self, model, task_name): task_index=self.config.tasks.index(task_name) - last_linear(model).weight = last_linear(self.classifiers[task_index]).weight + setattr(model,search_module(model,'linear',mode='class')[-1], self.classifiers[task_index]) return model def _init_weights(*args): pass