Add reft method #8723
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           develop    #8723     +/-   ##
===========================================
- Coverage    55.44%   54.77%   -0.67%
===========================================
  Files          631      645     +14
  Lines        98542   100066   +1524
===========================================
+ Hits         54632    54812    +180
- Misses       43910    45254   +1344
View full report in Codecov by Sentry.
llm/utils/argument.py
share_weights: bool = field(default=True, metadata={"help": "Flag indicating whether to share weights."})
greedy_decoding: bool = field(default=True, metadata={"help": "Flag indicating whether to use greedy decoding."})
temperature: float = field(default=None, metadata={"help": "Temperature parameter for decoding."})
num_hidden_layers: int = field(default=32)
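If ReFT needs many parameters, consider creating a separate REFTArgument class. For the submitted PR, please keep only the arguments that are actually necessary.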
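A minimal sketch of what a separate argument class could look like, assuming the same `dataclass`/`field` style used in `llm/utils/argument.py`. The class name `ReftArgument` and the `position` field are hypothetical illustrations, not names taken from this PR; `share_weights` is copied from the diff above.

```python
from dataclasses import dataclass, field


@dataclass
class ReftArgument:
    """Hypothetical container for ReFT-only options (names are illustrative)."""

    # Assumed field: which tokens to intervene on, e.g. "f5+l5".
    position: str = field(default="f5+l5", metadata={"help": "Token positions to intervene on."})
    # Copied from the diff above.
    share_weights: bool = field(default=True, metadata={"help": "Flag indicating whether to share weights."})


# Only the ReFT-specific options live here; generic ones stay in the existing argument classes.
args = ReftArgument(position="f3+l3")
print(args)
```

Grouping the method-specific options this way would keep the shared argument classes unchanged for users who do not use ReFT.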
done
llm/utils/compute_metrics.py
return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=collate_fn)


def compute_metrics_reft(
If this compute_metrics is not generally needed, I suggest removing it.
llm/utils/data.py
# paddle version: labels are shifted
class LoReftSupervisedDataset(ReftDataset):
What is the reason for adding this new dataset? Can the existing one not be reused?
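The input to the ReFT method carries a new intervention_locations field that gives the positions of the tokens to intervene on. For example, f5+l5 means intervening on the first 5 tokens and the last 5 tokens of the input.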
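To make the f5+l5 notation concrete, here is a rough sketch of how such a string could be turned into token indices. The helper names `parse_position` and `intervention_locations` are hypothetical, and the handling of short inputs (merging overlapping prefix and suffix positions) is an assumption; the PR's dataset code may pad or handle overlaps differently.

```python
import re
from typing import List, Tuple


def parse_position(position: str) -> Tuple[int, int]:
    """Return (first_n, last_n) from a string such as "f5+l5", "f3", or "l2"."""
    first_n = last_n = 0
    for part in position.split("+"):
        match = re.fullmatch(r"([fl])(\d+)", part)
        if match is None:
            raise ValueError(f"unrecognized position spec: {part!r}")
        if match.group(1) == "f":
            first_n = int(match.group(2))
        else:
            last_n = int(match.group(2))
    return first_n, last_n


def intervention_locations(position: str, seq_len: int) -> List[int]:
    """List the token indices covered by the first-n and last-n spans."""
    first_n, last_n = parse_position(position)
    first = list(range(min(first_n, seq_len)))
    last = list(range(max(seq_len - last_n, 0), seq_len))
    # Assumption: merge duplicates when prefix and suffix overlap on short inputs.
    return sorted(set(first + last))


print(intervention_locations("f5+l5", seq_len=12))
# [0, 1, 2, 3, 4, 7, 8, 9, 10, 11]
```

Because these indices depend on each example's length, they have to be computed per sample, which is why an extra field in the dataset is needed.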
paddlenlp/__init__.py
@@ -45,6 +45,7 @@
    peft,
    prompt,
    quantization,
    reft,
I suggest placing reft under the peft directory.
done
done
paddlenlp/reft/pareft/layers.py
print("n,m", n, m)

# weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Orthogonal())
# linear = paddle.nn.Linear(10, 15, weight_attr=weight_attr)
Please remove the unrelated comments, and describe in the PR which parts are the ReFT method and which parts are your own method.
done
paddlenlp/reft/pareft/dataset.py
# See the License for the specific language governing permissions and
# limitations under the License.

IGNORE_INDEX = -100
Why does ReFT need a separate dataset?
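The input to the ReFT method carries a new intervention_locations field that gives the positions of the tokens to intervene on. For example, f5+l5 means intervening on the first 5 tokens and the last 5 tokens of the input.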
from dataclasses import dataclass
from typing import Dict, Sequence

import paddle
Why does ReFT need a separate trainer?
PR types
New feature
PR changes
Add reft in paddlenlp/peft/reft
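reft
├── pareft
│   ├── config.py  ReFT configuration; inherits from pavenv.config
│   ├── dataset.py  ReFT data processing; the ReFT input carries a new intervention_locations field giving the positions of the tokens to intervene on, e.g. f5+l5 means intervening on the first 5 and last 5 tokens of the input
│   ├── __init__.py
│   ├── interventions.py  intervention networks
│   ├── reft_model.py  builds the ReFT model; inherits from pavenv.interventableModel
│   ├── reft_trainer.py  overrides compute_loss; ReFT must intervene on the hidden representations of the tokens at the positions given by the position parameter in the config
│   └── utils.py  utilities
└── pavenv
    ├── __init__.py
    └── models
        ├── basic_utils.py  basic utilities
        ├── configuration_intervenable_model.py  builds the configuration for intervention methods
        ├── constants.py  constants
        ├── __init__.py
        ├── intervenable_base.py  the main class implementing the method; it adds a forward_post_hook to the model, and during the forward pass the hook extracts the vectors at the intervention positions, feeds them to the intervention model, and replaces the vectors at those positions with the intervention model's output
        ├── intervenable_modelcard.py  model configuration information
        ├── interventions.py  parent class of all intervention networks
        ├── llama
        │   └── modelings_intervenable_llama.py  base configuration for the llama model
        └── modeling_utils.py  model utilities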
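The hook flow described for intervenable_base.py can be illustrated with a small, framework-free sketch. `ToyLayer`, `make_intervention_hook`, and `add_one` are hypothetical names used only for illustration; the toy hook method borrows the name of Paddle's `register_forward_post_hook`, but this is plain Python on lists and not the PR's implementation.

```python
from typing import Callable, List

Vector = List[float]
Hidden = List[Vector]  # one vector per token


class ToyLayer:
    """Stand-in for a model layer that supports one forward post-hook."""

    def __init__(self) -> None:
        self._post_hook = None

    def register_forward_post_hook(self, hook) -> None:
        self._post_hook = hook

    def forward(self, hidden: Hidden) -> Hidden:
        # Identity "computation" keeps the example focused on the hook.
        return [list(vec) for vec in hidden]

    def __call__(self, hidden: Hidden) -> Hidden:
        output = self.forward(hidden)
        if self._post_hook is not None:
            replaced = self._post_hook(self, hidden, output)
            if replaced is not None:
                output = replaced
        return output


def make_intervention_hook(locations: List[int], intervention: Callable[[Vector], Vector]):
    """Build a hook that rewrites the output vectors at the given token positions."""

    def hook(layer, inputs, output):
        patched = [list(vec) for vec in output]
        for pos in locations:
            # Extract the vector, run the intervention, and write the result back.
            patched[pos] = intervention(output[pos])
        return patched

    return hook


def add_one(vec: Vector) -> Vector:
    # Placeholder intervention; a real one would be a small trainable network.
    return [x + 1.0 for x in vec]


layer = ToyLayer()
layer.register_forward_post_hook(make_intervention_hook([0, 2], add_one))
result = layer([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
print(result)  # [[1.0, 1.0], [1.0, 1.0], [3.0, 3.0]]
```

Only the vectors at positions 0 and 2 change; the rest of the layer output passes through untouched, which is the behavior the description attributes to the hook.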
Description