ローカルLLM:医師国家試験を解かせてみた(GPUなしCPUのみ)Gemma3 4B-it-q4_k_m
https://github.com/jungokasai/IgakuQA
から問題をひっぱってきまして、ローカルLLMでの医師国家試験の回答精度を見てみました。
といってもすべて解くには時間がかかるので、問題の中で2018年の問題だけ解いてみました。
中には、
診察器具の写真(①〜⑤)を別に示す。成人に対して鼻処置を行った上で、鼻腔から上咽頭、喉頭にかけて内視鏡検査を実施する際に使用する器具はどれか。2つ選べ。
といった問題があって画像問題でそもそも解けないものも含まれてはいます。
また、”2つ選べ”という指示に対して、一つしか選択しか返せないプログラムで、一つでも選べると正解というあまあま採点から。(最終的には改良している)
2つ選べとか3つ選べいう設問に対応できるようにしました。そうすると正解率がさがりまして、甘い採点だったのが厳しくなり、24%の正答率。
トータルの正答率は、38%
112-B の49問を解くのにかかった時間は、
742秒、12分22秒(Ryzen3 3200GのCPUのみでの計算)
422秒、7分2秒(Ryzen5 5600のCPUのみでの計算)
いやーこれほどCPUの差を見たのは久です。
2019年に発売されたCPU。別にマイコンのファームを書いたり、CADで基板を設計したりという用途がほとんどだったので、特に困っていなかった。Windows11にもできたし。
しかし、あまりに作業効率が悪いので急遽中古でAM4マザボに対応するCPUでお手頃なRyzen5 5600を中古で購入。1万円なり。ここから上のCPUだと2万円とか・・・まあ、いいんだけどそこまでするならGPU買うし・・。
======================================
import json
import time
import subprocess
import re
import os
# LLM呼び出し関数
def query_llm(prompt: str) -> str:
llama_run_path = "/root/models/gemma/llama.cpp/build/bin/llama-run"
model_path = "/root/models/gemma/gemma-4b-it-q4_k_m.gguf"
cmd = [llama_run_path, model_path]
try:
result = subprocess.run(cmd, input=prompt, capture_output=True, text=True, check=True)
return result.stdout.strip()
except subprocess.CalledProcessError as e:
print(f"❌ LLM呼び出しエラー: {e}")
return "【エラー】Gemmaから回答を取得できませんでした"
# プロンプト構築
def build_prompt(question: str, choices: list) -> str:
choices_text = "\n".join([f"{chr(97+i)}. {c}" for i, c in enumerate(choices)])
if "2つ選べ" in question or "2つ選べ" in question:
instruction = "以下の選択肢から医学的根拠に基づいて最も適切なものを2つ選び、正解の小文字アルファベットを2つを, で区切りけ回答してください。最初に正解の小文字アルファベットを2つを回答し理由や根拠は示さず回答してください。"
elif "3つ選べ" in question or "3つ選べ" in question:
instruction = "以下の選択肢から医学的根拠に基づいて最も適切なものを3つ選び、正解の小文字アルファベットを3つを, で区切りけ回答してください。最初に正解の小文字アルファベットを3つを回答し理由や根拠は示さず回答してください。"
else:
instruction = "以下の選択肢から医学的根拠に基づいて最も適切なものを1つ選び、正解の小文字アルファベットを1つだけ回答してください。理由や根拠は示さず答えの小文字アルファベットのみを回答する"
return f"""Q: {question}
{instruction}
選択肢:
{choices_text}
【指示】
- 各選択肢の妥当性を医学的根拠に基づいて検討してください。
A:"""
# JSONL読み込み
def load_igakuqa_jsonl(path: str) -> list:
with open(path, "r", encoding="utf-8") as f:
return [json.loads(line.strip()) for line in f]
# JSONL保存
def save_igakuqa_jsonl(path: str, entries: list):
with open(path, "w", encoding="utf-8") as f:
for entry in entries:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
# 統計保存
def save_summary(path: str, total: int, correct: int, elapsed_sec: float):
accuracy = (correct / total) * 100 if total > 0 else 0
summary = {
"解いた問題数": total,
"正解数": correct,
"正答率": f"{accuracy:.1f}%",
"所要時間(秒)": round(elapsed_sec, 1),
"所要時間(分)": round(elapsed_sec / 60, 1)
}
with open(path, "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
# 回答記号抽出
def extract_choice_letters(answer_text: str) -> list:
candidates = set()
lines = answer_text.splitlines()
for line in lines:
if any(key in line for key in ["回答の選択肢", "回答は", "正解は", "選択肢は", "正答は"]):
candidates.update(re.findall(r'\b([a-e])\b', line))
candidates.update(re.findall(r"'([a-e])'", line))
if not candidates:
candidates.update(re.findall(r'\b([a-e])\b', answer_text))
return sorted(candidates) if candidates else ["?"]
# 🔁 ファイルループ処理 ["A", "B", "C", "D", "E", "F"]
file_suffixes = ["A", "B", "C", "D", "E", "F"]
base_path = "../data/2018"
for suffix in file_suffixes:
input_path = f"{base_path}/112-{suffix}.jsonl"
output_path = f"{base_path}/112-{suffix}-answered.jsonl"
summary_path = f"{base_path}/112-{suffix}-summary.json"
print(f"\n📁 ファイル 112-{suffix}.jsonl を処理中")
start_time = time.time()
qa_data = load_igakuqa_jsonl(input_path)
print(f"{len(qa_data)}問読み込みました")
total_questions = 0
correct_count = 0
for i, entry in enumerate(qa_data):
question = entry.get("problem_text", "").strip()
choices = entry.get("choices", [])
correct_answers = entry.get("answer", [])
if not question:
print(f"{i+1}問目は空の質問のためスキップ")
continue
print(f"\n{i+1}問目: {question}")
for idx, choice in enumerate(choices):
print(f" {chr(97+idx)}. {choice}")
prompt = build_prompt(question, choices)
model_output = query_llm(prompt)
print(model_output)
entry["model_answer"] = model_output
if not model_output or "エラー" in model_output:
print("⚠️ Gemmaから回答が得られませんでした。スキップします")
continue
model_choice = extract_choice_letters(model_output)
entry["model_choice"] = model_choice
is_correct = sorted(model_choice) == sorted(correct_answers)
result_text = "✅ 完全一致" if is_correct else "❌ 不一致"
print(f"{result_text}(モデル: {model_choice} / 正解: {correct_answers})")
total_questions += 1
if is_correct:
correct_count += 1
accuracy = (correct_count / total_questions) * 100 if total_questions > 0 else 0
print(f"📊 現在の統計 → 正解数: {correct_count} / {total_questions}(正解率: {accuracy:.1f}%)")
print("=" * 40)
time.sleep(0.5)
elapsed_sec = time.time() - start_time
save_igakuqa_jsonl(output_path, qa_data)
save_summary(summary_path, total_questions, correct_count, elapsed_sec)
print(f"✅ 回答付き保存 → {output_path}")
print(f"📁 統計保存 → {summary_path}")
コメント
コメントを投稿