inference_pipeline.py
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,  # only needed for the optional 4-bit loading sketch below
    pipeline,
    logging,
)

# Path to the fine-tuned model checkpoint
model_name = "./7b1"
# Load the fine-tuned model in half precision; device_map="auto" places it
# on the available GPU(s)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
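# A minimal 4-bit alternative (a sketch, assuming the bitsandbytes package is
# installed); this is what the BitsAndBytesConfig import above would be used
# for, and it can replace the half-precision load when GPU memory is tight:
#
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       quantization_config=bnb_config,
#       device_map="auto",
#   )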
# Load the tokenizer that ships with the model (Llama 2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Enable the KV cache for faster autoregressive decoding
model.config.use_cache = True
model.eval()

# Silence transformers warnings during generation
logging.set_verbosity(logging.CRITICAL)
def format_instruction(sample):
    # Wrap the user prompt in the instruction template used during fine-tuning
    return f"""### Instruction:
You are a coding assistant that will write a Solution to resolve the following Task:
### Task:
{sample}
### Solution:
""".strip()
# Build the generation pipeline once, outside the prompt loop
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    repetition_penalty=1.1,
)

print("Model loaded. Type 'exit' to quit.")
while True:
    prompt = input("Enter Your Prompt: ")
    if prompt == "exit":
        break
    final = format_instruction(prompt)
    # Sample with conservative settings: low temperature/top_p keep the
    # completions focused, and top_k trims the tail of the distribution
    result = pipe(
        final,
        do_sample=True,
        top_p=0.5,
        temperature=0.5,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=2048,
    )
    for output in result:
        print(output["generated_text"])
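# Note: max_length=2048 counts the prompt tokens as well as the completion.
# To bound only the newly generated tokens, the text-generation pipeline also
# accepts max_new_tokens, e.g. pipe(final, do_sample=True, max_new_tokens=512).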