# inference.py
import torch
import gc
import random
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
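# This script performs "directional ablation" (abliteration): it estimates a
# "refusal direction" in the residual stream by contrasting the model's hidden
# states on harmful vs. harmless instructions, then orthogonalizes selected
# output weights against that direction so the model can no longer write along it.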
#MODEL_ID = "Mistral-7B-Instruct-v0.2"
MODEL_ID = "miqu-1-70b-sf"
# More samples help pin down the refusal direction more accurately.
NUM_PROMPT_SAMPLES = 32
# Skip the first and/or last layers when applying the modifications.
SKIP_BEGIN_LAYERS = 1  # Don't mess with the first layer.
SKIP_END_LAYERS = 0
# The layer used for the refusal_dir calculation is floor(LAYER_FRACTION_TO_USE * num_layers).
LAYER_FRACTION_TO_USE = 0.6
# Use a negative scale_factor to "induce" refusal, or a positive scale_factor < 1 to "ablate" less aggressively.
SCALE_FACTOR = 1.0
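# For example, on an 80-layer model (typical for 70B Llama-2 derivatives),
# LAYER_FRACTION_TO_USE = 0.6 selects layer int(80 * 0.6) = 48.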
# NOTE: torch.inference_mode() returns a context manager, so calling it bare
# does nothing; torch.set_grad_enabled(False) is what actually disables
# autograd for the whole script.
torch.set_default_device("cpu")
torch.set_grad_enabled(False)
# Load the model onto the GPU in 4-bit quantized form if we can.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    low_cpu_mem_usage=True,
    device_map='auto'
)
model.requires_grad_(False)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
layer_idx = int(len(model.model.layers) * LAYER_FRACTION_TO_USE)
print(f"Layer index for refusal direction: {layer_idx}")
with open("harmful.txt", "r") as f:
harmful = f.readlines()
with open("harmless.txt", "r") as f:
harmless = f.readlines()
harmful_instructions = random.sample(harmful, min(NUM_PROMPT_SAMPLES, len(harmful)))
harmless_instructions = random.sample(harmless, min(NUM_PROMPT_SAMPLES, len(harmless)))
harmful_toks = [
    tokenizer.apply_chat_template(conversation=[{"role": "user", "content": insn}],
                                  add_generation_prompt=True, return_tensors="pt")
    for insn in harmful_instructions
]
harmless_toks = [
    tokenizer.apply_chat_template(conversation=[{"role": "user", "content": insn}],
                                  add_generation_prompt=True, return_tensors="pt")
    for insn in harmless_instructions
]
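# add_generation_prompt=True appends the tokens that open the assistant's turn,
# so the final prompt token sits exactly where the model would begin its reply;
# that position's hidden state is what generate() captures below.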
bar_generate = tqdm(total=len(harmful_instructions) + len(harmless_instructions), desc="Generating samples")
# Run a single forward pass (max_new_tokens=1) and return the chosen layer's
# hidden state at the final prompt token; move it to 'cpu' to save VRAM.
def generate(toks):
    output = model.generate(
        toks.to(model.device),
        use_cache=False,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_hidden_states=True,
        pad_token_id=tokenizer.eos_token_id
    )
    bar_generate.update(n=1)
    # output.hidden_states[0] holds the per-layer hidden states of the prompt
    # forward pass; [:, -1, :] selects the last token position.
    return output.hidden_states[0][layer_idx][:, -1, :].to('cpu')
harmful_hidden = [generate(toks) for toks in harmful_toks]
harmless_hidden = [generate(toks) for toks in harmless_toks]
bar_generate.close()
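# The refusal direction is the unit-normalized difference of the mean hidden
# states: r = (mu_harmful - mu_harmless) / ||mu_harmful - mu_harmless||.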
harmful_mean = torch.stack(harmful_hidden).mean(dim=0)
harmless_mean = torch.stack(harmless_hidden).mean(dim=0)
refusal_dir = harmful_mean - harmless_mean
refusal_dir = refusal_dir.squeeze() / refusal_dir.norm()
torch.save(refusal_dir, MODEL_ID.replace("/", "_") + "_refusal_dir.pt")
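# The saved direction can be reloaded in other scripts, e.g.:
#     refusal_dir = torch.load("miqu-1-70b-sf_refusal_dir.pt")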
# Free memory
del model
gc.collect()
torch.cuda.empty_cache()
# Reload the model in CPU memory with bfloat16 data type
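# (The 4-bit quantized weights from the first load can't be edited or saved at
# full precision, which is why the model is reloaded unquantized here.)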
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map='cpu'
)
model.requires_grad_(False)
# Make sure the direction is on the same device as the model ('cpu').
if refusal_dir.device != model.device:
    refusal_dir = refusal_dir.to(model.device)
# Get the language model component and check it's as expected.
lm_model = model.model
assert hasattr(lm_model, 'layers'), "The model does not have the expected structure."
# Check the ranges are valid.
num_layers = len(lm_model.layers)
assert SKIP_BEGIN_LAYERS >= 0, "SKIP_BEGIN_LAYERS must be >= 0."
assert SKIP_END_LAYERS >= 0, "SKIP_END_LAYERS must be >= 0."
assert SKIP_BEGIN_LAYERS + SKIP_END_LAYERS < num_layers, "SKIP_BEGIN_LAYERS + SKIP_END_LAYERS must be < num_layers."
bar_layers = tqdm(total=(num_layers - (SKIP_BEGIN_LAYERS + SKIP_END_LAYERS)) * 2, desc="Modifying tensors")
# Cast any ops performed on the CPU up to float32; on a newer CPU you may be
# able to use bfloat16 for this directly.
# NOTE: Use a negative scale_factor to "induce" refusal, or a positive
# scale_factor < 1 to "ablate" less aggressively.
def modify_tensor(tensor_data, refusal_dir, scale_factor: float = 1.0):
    assert scale_factor <= 1.0, "Using a scale_factor of > 1 doesn't make sense..."
    tensor_float32 = tensor_data.to(torch.float32)
    refusal_dir_float32 = refusal_dir.to(torch.float32)
    # W <- W - scale_factor * (r r^T) W: remove the component of the weight's
    # output that lies along the refusal direction r.
    tensor_float32 -= scale_factor * torch.matmul(
        torch.outer(refusal_dir_float32, refusal_dir_float32), tensor_float32
    )
    tensor_modified = tensor_float32.to(torch.bfloat16)
    bar_layers.update(1)
    return torch.nn.Parameter(tensor_modified, requires_grad=False)
# Modify the 'self_attn.o_proj.weight' and 'mlp.down_proj.weight' in each chosen layer.
# NOTE: These tensor names are specific to "llama"-family models and may need changing.
# - See here for others: https://github.com/arcee-ai/mergekit/tree/main/mergekit/_data/architectures
for layer_idx in range(SKIP_BEGIN_LAYERS, num_layers - SKIP_END_LAYERS):
    lm_model.layers[layer_idx].self_attn.o_proj.weight = modify_tensor(
        lm_model.layers[layer_idx].self_attn.o_proj.weight.data, refusal_dir, SCALE_FACTOR
    )
    lm_model.layers[layer_idx].mlp.down_proj.weight = modify_tensor(
        lm_model.layers[layer_idx].mlp.down_proj.weight.data, refusal_dir, SCALE_FACTOR
    )
bar_layers.close()
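# Optional sanity check (a minimal sketch, not part of the original script):
# with SCALE_FACTOR == 1.0 and a unit-norm refusal_dir, every modified weight
# should now satisfy r^T W ~= 0 (up to bfloat16 rounding error).
w = lm_model.layers[SKIP_BEGIN_LAYERS].self_attn.o_proj.weight.to(torch.float32)
residual = torch.matmul(refusal_dir.to(torch.float32), w).abs().max().item()
print(f"Max |r^T W| after ablation (expect ~0): {residual:.3e}")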
# Save the modified model and original tokenizer
print("Saving modified model (with original tokenizer)...")
#model.save_pretrained("Mistral-7B-Instruct-v0.2-fixed")
#tokenizer.save_pretrained("Mistral-7B-Instruct-v0.2-fixed")
model.save_pretrained("miqu-1-70b-sf-fixed")
tokenizer.save_pretrained("miqu-1-70b-sf-fixed")
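# To load the modified model afterwards (a minimal sketch):
#     model = AutoModelForCausalLM.from_pretrained("miqu-1-70b-sf-fixed", torch_dtype=torch.bfloat16)
#     tokenizer = AutoTokenizer.from_pretrained("miqu-1-70b-sf-fixed")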