Add reft method #8723

Closed


TranscenderNing (Contributor) commented Jul 7, 2024

PR types

New feature

PR changes

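Add reft in paddlenlp/peft/reft
reft
├── pareft
│   ├── config.py  reft configuration; inherits from pavenv.config
│   ├── dataset.py  reft data processing; reft inputs carry a new intervention_locations field giving the positions of the tokens to intervene on, e.g. f5+l5 means intervening on the first 5 and the last 5 tokens of the input
│   ├── __init__.py
│   ├── interventions.py  intervention networks
│   ├── reft_model.py  builds the reft model; inherits from pavenv.interventableModel
│   ├── reft_trainer.py  overrides compute_loss; reft has to intervene on the hidden representations of the tokens at the positions given by the position parameter in the config
│   └── utils.py  utilities
└── pavenv
    ├── __init__.py
    └── models
        ├── basic_utils.py  basic utilities
        ├── configuration_intervenable_model.py  builds the configuration for an intervention method
        ├── constants.py  constants
        ├── __init__.py
        ├── intervenable_base.py  the main class implementing the method; it adds a forward_post_hook to the model, and during the forward pass the hook extracts the vectors at the intervention positions, feeds them to the intervention model, and replaces the vectors at those positions with the intervention model's output
        ├── intervenable_modelcard.py  model configuration information
        ├── interventions.py  parent class of all intervention networks
        ├── llama
        │   └── modelings_intervenable_llama.py  base configuration for the llama model
        └── modeling_utils.py  model utilities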
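
To make the hook mechanism described for intervenable_base.py easier to picture, here is a minimal framework-free sketch. It is not the PR's code: `ToyLayer`, `make_intervention_hook`, and `scale_by_two` are hypothetical names, the layer is an identity stand-in for a transformer block, and the hook convention (a returned value replaces the layer output) is an assumption made for this illustration.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Hidden = List[Vector]  # one hidden vector per token


class ToyLayer:
    """Hypothetical stand-in for a model block that supports forward-post hooks."""

    def __init__(self) -> None:
        self._post_hooks: List[Callable[["ToyLayer", Hidden, Hidden], Optional[Hidden]]] = []

    def register_forward_post_hook(self, hook) -> None:
        self._post_hooks.append(hook)

    def forward(self, hidden: Hidden) -> Hidden:
        # Placeholder computation: return a copy of the input unchanged.
        return [list(vec) for vec in hidden]

    def __call__(self, hidden: Hidden) -> Hidden:
        output = self.forward(hidden)
        for hook in self._post_hooks:
            result = hook(self, hidden, output)
            if result is not None:  # assumed convention: a returned value replaces the output
                output = result
        return output


def make_intervention_hook(positions: Sequence[int], intervention: Callable[[Vector], Vector]):
    """Build a hook that rewrites only the hidden vectors at `positions`."""

    def hook(layer: ToyLayer, inputs: Hidden, output: Hidden) -> Hidden:
        patched = [list(vec) for vec in output]
        for pos in positions:
            # Extract the vector, run the intervention, write the result back.
            patched[pos] = intervention(output[pos])
        return patched

    return hook


def scale_by_two(vec: Vector) -> Vector:
    """Toy intervention used only for this example."""
    return [2.0 * x for x in vec]


layer = ToyLayer()
layer.register_forward_post_hook(make_intervention_hook([0, 3], scale_by_two))
hidden = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
patched = layer(hidden)
```

Only the vectors at positions 0 and 3 pass through the intervention; the others leave the layer untouched, which is the extract, intervene, replace flow the description outlines.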

Description


paddle-bot bot commented Jul 7, 2024

Thanks for your contribution!


CLAassistant commented Jul 7, 2024

CLA assistant check
All committers have signed the CLA.


codecov bot commented Jul 7, 2024

Codecov Report

Attention: Patch coverage is 19.14043% with 1618 lines in your changes missing coverage. Please review.

Project coverage is 54.77%. Comparing base (ee4944e) to head (2975154).

Current head 2975154 differs from pull request most recent head fdfb8c5

Please upload reports for the commit fdfb8c5 to get more accurate results.

| Files | Patch % | Lines |
|---|---|---|
| paddlenlp/reft/pavenv/models/intervenable_base.py | 9.88% | 711 Missing ⚠️ |
| paddlenlp/reft/pareft/dataset.py | 18.29% | 259 Missing ⚠️ |
| paddlenlp/reft/pavenv/models/modeling_utils.py | 14.22% | 193 Missing ⚠️ |
| paddlenlp/reft/pavenv/models/intervention_utils.py | 15.65% | 97 Missing ⚠️ |
| paddlenlp/reft/pavenv/models/interventions.py | 37.81% | 74 Missing ⚠️ |
| .../pavenv/models/configuration_intervenable_model.py | 12.98% | 67 Missing ⚠️ |
| paddlenlp/reft/pavenv/models/basic_utils.py | 26.31% | 56 Missing ⚠️ |
| paddlenlp/reft/pareft/reft_trainer.py | 32.92% | 55 Missing ⚠️ |
| paddlenlp/reft/pareft/interventions.py | 25.86% | 43 Missing ⚠️ |
| paddlenlp/reft/pareft/reft_model.py | 25.71% | 26 Missing ⚠️ |

... and 6 more
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8723      +/-   ##
===========================================
- Coverage    55.44%   54.77%   -0.67%     
===========================================
  Files          631      645      +14     
  Lines        98542   100066    +1524     
===========================================
+ Hits         54632    54812     +180     
- Misses       43910    45254    +1344     


share_weights: bool = field(default=True, metadata={"help": "Flag indicating whether to share weights."})
greedy_decoding: bool = field(default=True, metadata={"help": "Flag indicating whether to use greedy decoding."})
temperature: float = field(default=None, metadata={"help": "Temperature parameter for decoding."})
num_hidden_layers: int = field(default=32)
Contributor

If REFT has many parameters, consider adding a separate REFTArgument class. For the submitted PR, keep only the arguments that are actually necessary.
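
One way the suggested split could look is sketched below. The class name `ReftArgument` is hypothetical, the four fields are simply the ones quoted in the diff above, and the `Optional[float]` annotation and the help text for `num_hidden_layers` are additions made for this sketch rather than anything taken from the PR.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ReftArgument:
    """Hypothetical container that keeps ReFT-only options out of the shared arguments."""

    share_weights: bool = field(default=True, metadata={"help": "Flag indicating whether to share weights."})
    greedy_decoding: bool = field(default=True, metadata={"help": "Flag indicating whether to use greedy decoding."})
    # Optional because the default is None (assumption: None means "not set").
    temperature: Optional[float] = field(default=None, metadata={"help": "Temperature parameter for decoding."})
    # Help text below is an assumption; the quoted diff gives none for this field.
    num_hidden_layers: int = field(default=32, metadata={"help": "Number of hidden layers in the base model."})


# Usage: override only what differs from the defaults.
reft_args = ReftArgument(temperature=0.7)
```

Grouping the options this way keeps the existing argument classes unchanged for users who never enable ReFT.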

Contributor Author

done

return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=collate_fn)


def compute_metrics_reft(
Contributor

If this compute_metrics is not generally needed, please remove it.



# paddle version: labels are shifted
class LoReftSupervisedDataset(ReftDataset):
Contributor

Why add this new dataset? Can't the existing one be reused?

Contributor Author

The input to the reft method carries a new intervention_locations field that gives the positions of the tokens to intervene on. For example, f5+l5 means intervening on the first 5 and the last 5 tokens of the input.
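
As an illustration of the f5+l5 notation, the sketch below turns such a string into token indices. The helper names `parse_positions` and `intervention_locations` are hypothetical, and the handling of sequences shorter than the requested span (overlapping indices are merged) is an assumption of this sketch, not necessarily what the PR's dataset.py does.

```python
import re
from typing import List, Tuple


def parse_positions(spec: str) -> Tuple[int, int]:
    """Parse a spec such as 'f5+l5' into (first_n, last_n)."""
    first_n = last_n = 0
    for part in spec.split("+"):
        match = re.fullmatch(r"([fl])(\d+)", part.strip())
        if match is None:
            raise ValueError(f"unsupported position spec: {part!r}")
        kind, count = match.group(1), int(match.group(2))
        if kind == "f":
            first_n = count
        else:
            last_n = count
    return first_n, last_n


def intervention_locations(spec: str, seq_len: int) -> List[int]:
    """Return the sorted token indices to intervene on for one sequence."""
    first_n, last_n = parse_positions(spec)
    first = list(range(min(first_n, seq_len)))
    last = list(range(max(seq_len - last_n, 0), seq_len))
    # Assumption: for short sequences the two spans may overlap, so merge them.
    return sorted(set(first + last))


locations = intervention_locations("f5+l5", seq_len=12)
```

For a 12-token input this selects indices 0 to 4 and 7 to 11, which is why each sample needs the extra field alongside input_ids.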

@@ -45,6 +45,7 @@
peft,
prompt,
quantization,
reft,
Contributor

I suggest placing reft under the peft directory.

Contributor Author

done

Contributor Author

done

print("n,m", n, m)

# weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Orthogonal())
# linear = paddle.nn.Linear(10, 15, weight_attr=weight_attr)
Contributor

Remove the irrelevant comments, and describe in the PR which parts are the reft method and which parts are your own.

Contributor Author

done

# See the License for the specific language governing permissions and
# limitations under the License.

IGNORE_INDEX = -100
Contributor

Why does reft need a separate dataset?

Contributor Author

The input to the reft method carries a new intervention_locations field that gives the positions of the tokens to intervene on. For example, f5+l5 means intervening on the first 5 and the last 5 tokens of the input.
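
A rough sketch of what one training record might then contain is shown below. `build_example` is a hypothetical helper, and two things are assumptions rather than facts from the PR: that prompt tokens are masked with IGNORE_INDEX so only the answer contributes to the loss, and that labels are shifted by one position inside the dataset.

```python
from typing import Dict, List, Sequence

IGNORE_INDEX = -100  # label value excluded from the loss, as in the quoted file


def build_example(
    prompt_ids: Sequence[int],
    answer_ids: Sequence[int],
    positions: Sequence[int],
) -> Dict[str, List[int]]:
    """Assemble one record with the extra intervention_locations field."""
    tokens = list(prompt_ids) + list(answer_ids)
    # Assumption: mask the prompt so the loss covers the answer tokens only.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(answer_ids)
    # Assumption: shift by one in the dataset, so input token t is paired with label t + 1.
    input_ids = tokens[:-1]
    shifted_labels = labels[1:]
    return {
        "input_ids": input_ids,
        "labels": shifted_labels,
        # Keep only positions that still exist after the shift.
        "intervention_locations": [p for p in positions if p < len(input_ids)],
    }


example = build_example(prompt_ids=[11, 12, 13], answer_ids=[21, 22], positions=[0, 2])
```

The record differs from an ordinary supervised sample only by the intervention_locations entry.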

from dataclasses import dataclass
from typing import Dict, Sequence

import paddle
Contributor

Why does reft need a separate trainer?

paddlenlp/reft/pavenv/models/__init__.py (outdated)
5 participants