The final component is choice shuffling ensembling. By shuffling the order of the answer choices for each test question, we create multiple variants of the same question. The LLM is then prompted with each variant, along with the corresponding few-shot examples, to generate reasoning steps and an answer for that variant. Finally, we perform a majority vote over the predictions for all variants and select the final prediction.
The code related to this implementation can be found in this GitHub repository.
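To make the idea concrete before walking through the full pipeline, here is a minimal, self-contained sketch of the choice shuffling ensemble. The predict_letter function is purely illustrative (a stand-in for the LLM call) and is not part of the actual implementation:

import random
from collections import Counter

def predict_letter(question, options):
    # Stand-in for an LLM call; always "picks" the option containing 'aspirin'.
    return next(label for label, text in options.items() if "aspirin" in text.lower())

question = "Which drug irreversibly inhibits cyclooxygenase?"
options = {"A": "Aspirin", "B": "Ibuprofen", "C": "Naproxen", "D": "Celecoxib"}

votes = []
for _ in range(5):
    # Shuffle the option texts and re-assign the labels A-D
    shuffled = dict(zip("ABCD", random.sample(list(options.values()), k=len(options))))
    letter = predict_letter(question, shuffled)
    votes.append(shuffled[letter])          # vote on the option text, not the label

print(Counter(votes).most_common(1)[0][0])  # majority-voted answer: 'Aspirin'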
We use the MedQA dataset (6) to implement and evaluate Medprompt. We first define helper functions to read and write JSONL files.
import json

def write_jsonl_file(file_path, dict_list):
    """
    Write a list of dictionaries to a JSON Lines file.

    Args:
        file_path (str): The path to the file where the data will be written.
        dict_list (list): A list of dictionaries to write to the file.
    """
    with open(file_path, 'w') as file:
        for dictionary in dict_list:
            json_line = json.dumps(dictionary)
            file.write(json_line + '\n')

def read_jsonl_file(file_path):
    """
    Parses a JSONL (JSON Lines) file and returns a list of dictionaries.

    Args:
        file_path (str): The path to the JSONL file to be read.

    Returns:
        list of dict: A list where each element is a dictionary representing
        a JSON object from the file.
    """
    jsonl_lines = []
    with open(file_path, 'r', encoding="utf-8") as file:
        for line in file:
            json_object = json.loads(line)
            jsonl_lines.append(json_object)
    return jsonl_lines
Implementation of self-generated CoT
For our implementation, we use the MedQA training set. We implement a zero-shot CoT prompt, process all the training questions with GPT-4o, and generate the CoT reasoning and the corresponding answer for each question. We define a prompt based on the template provided in the Medprompt paper.
system_prompt = """You are an expert medical professional. You are provided with a medical question with multiple answer choices.
Your goal is to think through the question carefully and explain your reasoning step by step before selecting the final answer.
Respond only with the reasoning steps and answer as specified below.
Below is the format for each question and answer:

Input:
## Question: {{question}}
{{answer_choices}}

Output:
## Answer
(model generated chain of thought explanation)
Therefore, the answer is [final model answer (e.g. A,B,C,D)]"""
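The prompt-building code below relies on three helpers, create_query, format_answer, and build_zero_shot_prompt, that are used later but not reproduced in this excerpt. A minimal sketch consistent with how they are called and with the prompt template above might look like this (the exact formatting in the original repository may differ):

def create_query(item):
    # Formats a MedQA item (question + options) as the user message.
    options_text = "\n".join(f"{key}. {value}" for key, value in item["options"].items())
    return f"## Question: {item['question']}\n{options_text}"

def format_answer(cot, answer_idx):
    # Formats a CoT explanation and answer letter in the expected output format.
    return f"## Answer\n{cot}\nTherefore, the answer is {answer_idx}"

def build_zero_shot_prompt(system_prompt, question):
    # Builds a zero-shot prompt: a system message plus the question as a user message.
    return [{"role": "system", "content": system_prompt},
            {"role": "user", "content": create_query(question)}]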
def build_few_shot_prompt(system_prompt, question, examples, include_cot=True):
    """
    Builds the few-shot prompt from the provided examples.

    Args:
        system_prompt (str): Task instruction for the LLM.
        question (dict): The question for which to create a query, formatted as
        required by `create_query`.
        examples (list of dict): Few-shot examples, each containing a question,
        options, CoT reasoning and answer index.
        include_cot (bool): Whether to include the CoT reasoning in the example answers.

    Returns:
        list of dict: A list of messages, including a system message defining
        the task, the few-shot examples, and a user message with the input question.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for elem in examples:
        messages.append({"role": "user", "content": create_query(elem)})
        if include_cot:
            messages.append({"role": "assistant", "content": format_answer(elem["cot"], elem["answer_idx"])})
        else:
            answer_string = f"""## Answer\nTherefore, the answer is {elem["answer_idx"]}"""
            messages.append({"role": "assistant", "content": answer_string})
    messages.append({"role": "user", "content": create_query(question)})
    return messages
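The get_response function below assumes that an OpenAI client object named client has already been created, for example:

from openai import OpenAI

# Assumes the OPENAI_API_KEY environment variable is set.
client = OpenAI()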
def get_response(messages, model_name, temperature=0.0, max_tokens=10):
    """
    Obtains the response/answer of the model through the chat completions API.

    Args:
        messages (list of dict): The built messages provided to the API.
        model_name (str): Name of the model to access through the API.
        temperature (float): A value between 0 and 1 that controls the randomness of the output.
        A temperature value of 0 ideally makes the model pick the most likely token, making the outputs deterministic.
        max_tokens (int): Maximum number of tokens that the model should generate.

    Returns:
        str: The response message content from the model.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
We also define helper functions to analyze the reasoning and the final answer option of the LLM response.
import re

def matches_ans_option(s):
    """
    Checks if the string starts with the specific pattern 'Therefore, the answer is (A-Z)'.

    Args:
        s (str): The string to be checked.

    Returns:
        bool: True if the string matches the pattern, False otherwise.
    """
    return bool(re.match(r'^Therefore, the answer is [A-Z]', s))

def extract_ans_option(s):
    """
    Extracts the answer option (a single capital letter) from the start of the string.

    Args:
        s (str): The string containing the answer pattern.

    Returns:
        str or None: The captured answer option if the pattern is found, otherwise None.
    """
    match = re.search(r'^Therefore, the answer is ([A-Z])', s)
    if match:
        return match.group(1)  # Returns the captured letter
    return None
def matches_answer_start(s):
    """
    Checks if the string starts with the markdown header '## Answer'.

    Args:
        s (str): The string to be checked.

    Returns:
        bool: True if the string starts with '## Answer', False otherwise.
    """
    return s.startswith("## Answer")

def validate_response(s):
    """
    Validates that a multi-line string response starts with '## Answer' and ends with the answer pattern.

    Args:
        s (str): The multi-line string response to be validated.

    Returns:
        bool: True if the response is valid, False otherwise.
    """
    file_content = s.split("\n")
    return matches_ans_option(file_content[-1]) and matches_answer_start(s)

def parse_answer(response):
    """
    Parses a response that starts with '## Answer', extracting the reasoning and the answer choice.

    Args:
        response (str): The multi-line string response containing the answer and reasoning.

    Returns:
        tuple: A tuple containing the extracted CoT reasoning and the answer choice.
    """
    split_response = response.split("\n")
    assert split_response[0] == "## Answer"
    cot_reasoning = "\n".join(split_response[1:-1]).strip()
    ans_choice = extract_ans_option(split_response[-1])
    return cot_reasoning, ans_choice
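As a quick illustration with a made-up response string (not actual model output), the validation and parsing helpers behave as follows:

sample_response = "## Answer\nThe presentation is most consistent with option B.\nTherefore, the answer is B"

print(validate_response(sample_response))   # True
cot, ans = parse_answer(sample_response)
print(ans)                                  # 'B'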
Now we process the questions in the MedQA training set. We get the CoT responses and the answers to all the questions and store them in a folder.
import os
from tqdm import tqdm

train_data = read_jsonl_file("data/phrases_no_exclude_train.jsonl")
cot_responses = []

# os.mkdir("cot_responses")
existing_files = os.listdir("cot_responses/")

for idx, item in enumerate(tqdm(train_data)):
    if str(idx) + ".txt" in existing_files:
        continue
    prompt = build_zero_shot_prompt(system_prompt, item)
    try:
        response = get_response(prompt, model_name="gpt-4o", max_tokens=500)
        cot_responses.append(response)
        with open(os.path.join("cot_responses", str(idx) + ".txt"), "w", encoding="utf-8") as f:
            f.write(response)
    except Exception as e:
        print(str(e))
        cot_responses.append("")
We now iterate over all the generated responses to check whether they are valid and adhere to the prediction format defined in the prompt. We discard responses that do not conform to the required format. After that, we check the predicted answers against the ground truth for each question and only retain the questions for which the predicted answers match the ground truth.
questions_dict = []
ctr = 0
for idx, question in enumerate(tqdm(train_data)):
    file = open(os.path.join("cot_responses/", str(idx) + ".txt"), encoding="utf-8").read()
    if not validate_response(file):
        continue
    cot, pred_ans = parse_answer(file)
    dict_elem = {}
    dict_elem["idx"] = idx
    dict_elem["question"] = question["question"]
    dict_elem["answer"] = question["answer"]
    dict_elem["options"] = question["options"]
    dict_elem["cot"] = cot
    dict_elem["pred_ans"] = pred_ans
    questions_dict.append(dict_elem)

filtered_questions_dict = []
for item in tqdm(questions_dict):
    pred_ans = item["options"][item["pred_ans"]]
    if pred_ans == item["answer"]:
        filtered_questions_dict.append(item)
Implementation of the KNN model
After processing the training set and obtaining the CoT responses for all these questions, we now embed all the questions using the text-embedding-ada-002 model from OpenAI.
def get_embedding(text, model="text-embedding-ada-002"):
    return client.embeddings.create(input=[text], model=model).data[0].embedding

for item in tqdm(filtered_questions_dict):
    item["embedding"] = get_embedding(item["question"])
    inv_options_map = {v: k for k, v in item["options"].items()}
    item["answer_idx"] = inv_options_map[item["answer"]]
We now train a KNN model using these question embeddings. It acts as a retriever at inference time, helping us fetch the data points from the training set that are most similar to each test question.
import numpy as np
from sklearn.neighbors import NearestNeighbors

embeddings = np.array([d["embedding"] for d in filtered_questions_dict])
indices = list(range(len(filtered_questions_dict)))

knn = NearestNeighbors(n_neighbors=5, algorithm='auto', metric='cosine').fit(embeddings)
Implementation of dynamic few-shot selection and the choice shuffling ensemble
We can now run inference. We take a sample of 500 questions from the MedQA test set for our evaluation. For each question, we retrieve the 5 most similar questions from the training set using the KNN model, along with their respective CoT reasoning steps and predicted answers, and construct a few-shot prompt from these examples.
For each question, we also shuffle the order of the options 5 times to create different variants. We then use the few-shot prompt to obtain the predicted answer for each variant with shuffled options.
import random

def shuffle_option_labels(answer_options):
    """
    Shuffles the options of the question.

    Parameters:
        answer_options (dict): A dictionary with the options.

    Returns:
        dict: A new dictionary with the shuffled options.
    """
    options = list(answer_options.values())
    random.shuffle(options)
    labels = [chr(i) for i in range(ord('A'), ord('A') + len(options))]
    shuffled_options_dict = {label: option for label, option in zip(labels, options)}
    return shuffled_options_dict
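For example, applied to a made-up options dictionary, this re-labels the shuffled options from 'A' onwards:

options = {"A": "Aspirin", "B": "Ibuprofen", "C": "Paracetamol", "D": "Naproxen"}
print(shuffle_option_labels(options))
# e.g. {'A': 'Paracetamol', 'B': 'Aspirin', 'C': 'Naproxen', 'D': 'Ibuprofen'}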
test_samples = read_jsonl_file("final_processed_test_set_responses_medprompt.jsonl")

for question in tqdm(test_samples, colour="green"):
    question_variants = []
    prompt_variants = []
    cot_responses = []
    question_embedding = get_embedding(question["question"])
    distances, top_k_indices = knn.kneighbors([question_embedding], n_neighbors=5)
    top_k_dicts = [filtered_questions_dict[i] for i in top_k_indices[0]]
    question["outputs"] = []

    # Create five variants of the question with shuffled answer choices
    for idx in range(5):
        question_copy = question.copy()
        shuffled_options = shuffle_option_labels(question["options"])
        inv_map = {v: k for k, v in shuffled_options.items()}
        question_copy["options"] = shuffled_options
        question_copy["answer_idx"] = inv_map[question_copy["answer"]]
        question_variants.append(question_copy)
        prompt = build_few_shot_prompt(system_prompt, question_copy, top_k_dicts)
        prompt_variants.append(prompt)

    # Obtain a CoT response for each variant
    for prompt in tqdm(prompt_variants):
        response = get_response(prompt, model_name="gpt-4o", max_tokens=500)
        cot_responses.append(response)

    # Parse each response and map the predicted label back to the option text
    for question_sample, answer in zip(question_variants, cot_responses):
        if validate_response(answer):
            cot, pred_ans = parse_answer(answer)
        else:
            cot = ""
            pred_ans = ""
        question["outputs"].append({"question": question_sample["question"],
                                    "options": question_sample["options"],
                                    "cot": cot,
                                    "pred_ans": question_sample["options"].get(pred_ans, "")})
We now evaluate the results of Medprompt on the test set. For each question, we have five predictions generated through the ensemble logic. We take the mode, i.e. the most frequently occurring prediction, for each question as the final prediction and evaluate the performance. Two edge cases are possible here:
- Two different answer options are each predicted the same number of times, with no clear winner.
- The generated response is malformed, meaning we have no predicted answer option at all.
In both edge cases, we consider the question to have been answered incorrectly by the LLM.
from collections import Counter

def find_mode_string_list(string_list):
    """
    Finds the most frequently occurring strings.

    Parameters:
        string_list (list of str): A list of strings.

    Returns:
        list of str or None: A list containing the most frequent string(s) from the input list.
        Returns None if the input list is empty.
    """
    if not string_list:
        return None

    string_counts = Counter(string_list)
    max_freq = max(string_counts.values())
    mode_strings = [string for string, count in string_counts.items() if count == max_freq]
    return mode_strings
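As a small illustration with made-up prediction lists, a clear majority returns a single mode, while a 2-2-1 tie returns two modes, which the evaluation loop below treats as an incorrect answer:

print(find_mode_string_list(["Aspirin", "Aspirin", "Ibuprofen", "Aspirin", "Naproxen"]))
# ['Aspirin'] -> clear winner

print(find_mode_string_list(["Aspirin", "Aspirin", "Ibuprofen", "Ibuprofen", "Naproxen"]))
# ['Aspirin', 'Ibuprofen'] -> tie, counted as incorrect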
ctr = 0
for item in test_samples:
    pred_ans = [x["pred_ans"] for x in item["outputs"]]
    freq_ans = find_mode_string_list(pred_ans)

    if len(freq_ans) > 1:
        final_prediction = ""
    else:
        final_prediction = freq_ans[0]

    if final_prediction == item["answer"]:
        ctr += 1

print(ctr / len(test_samples))
We evaluate the performance of Medprompt with GPT-4o in terms of accuracy on the MedQA test subset, and compare it against Zero-shot prompting, Random Few-shot prompting, and Random Few-shot prompting with CoT.
We observe that both Medprompt and Random Few-shot CoT prompting outperform the Zero-shot and Random Few-shot baselines. Surprisingly, however, Random Few-shot CoT prompting outperforms Medprompt. This could be due to a couple of reasons:
- The original Medprompt paper evaluated the performance of GPT-4. We observed that GPT-4o significantly outperforms GPT-4T and GPT-4 on several text benchmarks (https://openai.com/index/hello-gpt-4o/), indicating that Medprompt might have a smaller effect on a stronger model such as GPT-4o.
- We restricted our evaluation to 500 questions drawn from MedQA. The Medprompt paper evaluates other medical multiple-choice question datasets and the full version of MedQA. Evaluating GPT-4o on the full versions of the datasets might provide a better picture of overall performance.
Medprompt is an interesting framework for building sophisticated prompting pipelines, in particular for tailoring a generalist LLM to a specific domain without the need for fine-tuning. It also highlights the considerations involved in deciding between prompt engineering and fine-tuning for various use cases. It is important to explore how far prompting can be pushed to improve LLM performance, as it offers a cost-effective and resource-efficient alternative to fine-tuning.
(1) Nori, H., Lee, Y.T., Zhang, S., Carignan, D., Edgar, R., Fusi, N., … & Horvitz, E. (2023). Can generalist foundation models outcompete special-purpose tuning? Case study in medicine. arXiv preprint arXiv:2311.16452. (https://arxiv.org/abs/2311.16452)
(2) Wei, J., Wang, X., Schuurmans, D., Bosma, M., Xia, F., Chi, E., … & Zhou, D. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824–24837. (https://openreview.net/pdf?id=_VjQlMeSB_J)
(3) Gekhman, Z., Yona, G., Aharoni, R., Eyal, M., Feder, A., Reichart, R., & Herzig, J. (2024). Does fine-tuning LLMs on new knowledge encourage hallucinations? arXiv preprint arXiv:2405.05904. (https://arxiv.org/abs/2405.05904)
(4) Singhal, K., Azizi, S., Tu, T., Mahdavi, S.S., Wei, J., Chung, H.W., … & Natarajan, V. (2023). Large language models encode clinical knowledge. Nature, 620(7972), 172–180. (https://www.nature.com/articles/s41586-023-06291-2)
(5) Singhal, K., Tu, T., Gottweis, J., Sayres, R., Wulczyn, E., Hou, L., … & Natarajan, V. (2023). Towards expert-level medical question answering with large language models. arXiv preprint arXiv:2305.09617. (https://arxiv.org/abs/2305.09617)
(6) Jin, D., Pan, E., Oufattole, N., Weng, W.H., Fang, H., & Szolovits, P. (2021). What disease does this patient have? A large-scale open domain question answering dataset from medical exams. Applied Sciences, 11(14), 6421. (https://arxiv.org/abs/2009.13081) (The original dataset is released under an MIT license)