-
Reasoning On Graphs: Faithful and Interpretable Large Language Model Reasoning 논문 리뷰 (ICLR 2024)AI/NLP 2024. 4. 22. 10:49728x90
Reasoning On Graphs: Faithful and Interpretable Large Language Model Reasoning 논문 리뷰
(ICLR 2024)- 논문 읽은 후
Introduction
- 배경
- LLM의 Challenges : lack of knowledge & hallucination
- 이 원인은 reasoning process 에서 발생한 에러
- 이런 에러들이 legal judgement, medical diagnosis에 발생한다면 큰 문제
- LLM의 Challenges : lack of knowledge & hallucination
- Contribution
- 이를 해결하기 위해 Knowledge graph, 그래프 내에 엔티티 간 semantic info.를 담아주는 데이터 표현 방식을 사용하겠다
- “faithful and interpretable reasoning” 지식그래프를 활용할 시 얻게되는 두 이점을 나타낸 키워드라고 할 수 있습니다. 정말 답변에 참고한 데이터가 있는지, 그리고 그 데이터가 어떻게 형성되어 있는지를 그래프 형태로 직관적으로 볼 수 있다 라는 점
- 이를 해결하기 위해 Knowledge graph, 그래프 내에 엔티티 간 semantic info.를 담아주는 데이터 표현 방식을 사용하겠다
Approach :
출처
https://www.graphusergroup.com/24-2weeks-april-graph-omakase/
- 방법론
- 잘 가져오고, 가져온것이 최적인가를 검증한 뒤 검증된 Path 를 LLM에게 전달하는 것 이 방법론 핵심입니다. 이를 위해 ELBO(evidence lower bound) 를 활용해 planning 로 생성된 relation path와 retrieval-reasoning 로 생성된 reasoning path 간의 정보량을 계산해 그 간극을 최소화하는 방식으로 최적화가 이루어집니다.
- 이 때, 최적화의 대상은 LLM의 parameter입니다. 다시 말해서, 무엇을 가져오고 가져온 것을 기반으로 답을 만들 때 적합한지를 지속적으로 개선한다는 것이 본 아키텍쳐의 핵심이라 할 수 있습니다.
- 방금 말한 방법론은 크게 3가지 모듈로 나누어집니다. 1. Planning , 2. Retrieval , 3. Reasoning 각각 무엇을 가져올지 계획하는 단계 , 실제 그 무엇을 가져오는 단계 그리고 가져온 것을 Reasoning 형태로 LLM에게 제공하는 단계입니다.
- 1. Planning optimization
- planning은 relation Paths를 가져오는 KG(knowledge graph) 로부터 추출할 때 무엇이 좋을지 계획하는 단계입니다. 유저의 질문을 기반으로 답변에 도움이 될만한 요소를 Knowledge graph 에서 추출합니다. 논문에서 “distill knowledge from KGs” 라고 표현할 만큼 최적의 결과 값을 뽑아내는게 핵심입니다.
- 최적화 기준은 question 으로 부터 생성한 realtion path와 실제 path 가 answer 가 relation path와 연결되었는지를 Kullback–Leibler divergence(KLD , 쿨랙라이블러 발산) 을 활용하여 비교합니다. 간단히 말해서, 두 path 간 확률 분포 차를 계산하고 이를 최적화한다 라고 보시면 되겠습니다.
- 이 때, 확률 분포를 계산할 때 우선 정규 분포를 가정하고 question 과 상응하는 answer이 서브그래프 내 존재할 시 이를 반영하는 방식으로 진행됩니다. 값은 shortest path 의 역수를 반영합니다.
- 2. Retrieval-reasoning optimization
- Retrieval-reasoning은 planning 으로부터 생성된 여러 path 들 중 무슨 path 가 과연 의미할지를 연산하는 단계입니다. FiD 프레임워크를 활용합니다. FiD 프레임워크란 Fusion-in-Decoder (FiD)의 약자입니다.
- 다양한 passage 를 독립적으로 인코딩하고, 디코딩시에 독립적으로 인코딩 된 passage 값을 fusion 하기 위한 방법론으로써, planning 에서 생성된 여러개의 faithful relation 을 활용하기 위해 FiD 아이디어를 차용합니다.
- 1. Planning optimization
- 서두에 언급드렸다시피, 본 논문의 목적은 LLM의 파라미터를 학습하는게 목적입니다. 지금까지는 지식그래프에서 어떻게 가져오고 어떻게 주입하는지를 이야기했다면, 다음부터는 구체적으로 어떤 input형태로 LLM에게 주입되어 학습되는지에 대해 이야기합니다.
- 1. planning module
- Planning 은 relation path가 유의미한지를 LLM에게 토큰 형태로 주입해 학습하는 단계입니다. LLM 최적화를 위해 , relation 을 토큰 단위로 분절합니다. path / sep / path 3가지 토큰을 활용합니다.
- <path> r1 <sep> r2 <sep> … <sep> ri </path> 형태로 주입되며 특정 path 마다 어떤 relation 이 담겨있는지를 토큰형태로 LLM에게 주입하고 이를 파라미터 최적화에 활용합니다.
- 2. Retrieval - Reasoning module
- 주어진 질문 그리고 relation path 를 활용해 사전에 형성되어 있는 Knowledge graph로 부터 가져오고, 이를 종합해 reasoning paths 그리고 질문 으로 가공하여 LLM에게 주입합니다. 이 때, reasoning paths 의 여러 path들 중 어느 path가 중요한지를 판별하기 위해 reasoning module을 활용합니다.
- Reasoning module 은 path 결과물마다 answer 값이 정확한지 부정확한지를 확률 형태로 추출한 뒤 이를 기반으로 중요도를 판별하는 역할을 합니다.
- 1. planning module
Experiments & Results
- 실험 세팅
- Embedding , Retrieval , Semantic Parsing , LLMs , LLMs+KGs 총 5가지 방법론들을 활용해 실험
- Backbone LLM으로는 LLama-7b-chat 모델을 활용
- 결과
- RoG방법론이 LLaMA2 모델 뿐만아니라, 타 LLM 모델인 ChatGPT , Alpaca-7B , Flan-T5 에도 활용성이 높음
- 1. Ablation Study
- 2. Knowledge graph transferability
- 다른 Knowledge Graph에도 우리 성능 괜찮아 ~
- 2. Retrieval time with Knowledge graph hops
- Reasoning 을 통해 LLM의 답변 성능이 좋아진다 해도 사용자에게 전달되기까지 시간이 기존 대비 많이 소요된다면 실용성 떨어짐 -> Retrieval time 중요함 ㅇㅇ
Code
https://github.com/RManLuo/reasoning-on-graphs
일단 나도 비슷한 프로젝트를 했기 때문에, retrieval한 그래프를 어떻게 prompt에 넘겨줬는지 궁금했다
import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/..") import utils import random from typing import Callable class PromptBuilder(object): MCQ_INSTRUCTION = """Please answer the following questions. Please select the answers from the given choices and return the answer only.""" SAQ_INSTRUCTION = """Please answer the following questions. Please keep the answer as simple as possible and return all the possible answer as a list.""" MCQ_RULE_INSTRUCTION = """Based on the reasoning paths, please answer the given question. Please select the answers from the given choices and return the answers only.""" SAQ_RULE_INSTRUCTION = """Based on the reasoning paths, please answer the given question. Please keep the answer as simple as possible and return all the possible answers as a list.""" COT = """ Let's think it step by step.""" EXPLAIN = """ Please explain your answer.""" QUESTION = """Question:\n{question}""" GRAPH_CONTEXT = """Reasoning Paths:\n{context}\n\n""" CHOICES = """\nChoices:\n{choices}""" EACH_LINE = """ Please return each answer in a new line.""" def __init__(self, prompt_path, add_rule = False, use_true = False, cot = False, explain = False, use_random = False, each_line = False, maximun_token = 4096, tokenize: Callable = lambda x: len(x)): self.prompt_template = self._read_prompt_template(prompt_path) self.add_rule = add_rule self.use_true = use_true self.use_random = use_random self.cot = cot self.explain = explain self.maximun_token = maximun_token self.tokenize = tokenize self.each_line = each_line def _read_prompt_template(self, template_file): with open(template_file) as fin: prompt_template = f"""{fin.read()}""" return prompt_template def apply_rules(self, graph, rules, srouce_entities): results = [] for entity in srouce_entities: for rule in rules: res = utils.bfs_with_rule(graph, entity, rule) results.extend(res) return results def direct_answer(self, question_dict): graph = utils.build_graph(question_dict['graph']) entities = question_dict['q_entity'] rules = question_dict['predicted_paths'] prediction = [] if len(rules) > 0: reasoning_paths = self.apply_rules(graph, rules, entities) for p in reasoning_paths: if len(p) > 0: prediction.append(p[-1][-1]) return prediction def process_input(self, question_dict): ''' Take question as input and return the input with prompt ''' question = question_dict['question'] if not question.endswith('?'): question += '?' if self.add_rule: graph = utils.build_graph(question_dict['graph']) entities = question_dict['q_entity'] if self.use_true: rules = question_dict['ground_paths'] elif self.use_random: _, rules = utils.get_random_paths(entities, graph) else: rules = question_dict['predicted_paths'] if len(rules) > 0: reasoning_paths = self.apply_rules(graph, rules, entities) lists_of_paths = [utils.path_to_string(p) for p in reasoning_paths] # context = "\n".join([utils.path_to_string(p) for p in reasoning_paths]) else: lists_of_paths = [] #input += self.GRAPH_CONTEXT.format(context = context) input = self.QUESTION.format(question = question) # MCQ if len(question_dict['choices']) > 0: choices = '\n'.join(question_dict['choices']) input += self.CHOICES.format(choices = choices) if self.add_rule: instruction = self.MCQ_RULE_INSTRUCTION else: instruction = self.MCQ_INSTRUCTION # SAQ else: if self.add_rule: instruction = self.SAQ_RULE_INSTRUCTION else: instruction = self.SAQ_INSTRUCTION if self.cot: instruction += self.COT if self.explain: instruction += self.EXPLAIN if self.each_line: instruction += self.EACH_LINE if self.add_rule: other_prompt = self.prompt_template.format(instruction = instruction, input = self.GRAPH_CONTEXT.format(context = "") + input) context = self.check_prompt_length(other_prompt, lists_of_paths, self.maximun_token) input = self.GRAPH_CONTEXT.format(context = context) + input input = self.prompt_template.format(instruction = instruction, input = input) return input def check_prompt_length(self, prompt, list_of_paths, maximun_token): '''Check whether the input prompt is too long. If it is too long, remove the first path and check again.''' all_paths = "\n".join(list_of_paths) all_tokens = prompt + all_paths if self.tokenize(all_tokens) < maximun_token: return all_paths else: # Shuffle the paths random.shuffle(list_of_paths) new_list_of_paths = [] # check the length of the prompt for p in list_of_paths: tmp_all_paths = "\n".join(new_list_of_paths + [p]) tmp_all_tokens = prompt + tmp_all_paths if self.tokenize(tmp_all_tokens) > maximun_token: return "\n".join(new_list_of_paths) new_list_of_paths.append(p)
- 실제는 이렇게 된답니다
Paper & Ref.
https://arxiv.org/abs/2310.01061
https://www.graphusergroup.com/24-2weeks-april-graph-omakase/
728x90'AI > NLP' 카테고리의 다른 글