ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Reasoning On Graphs: Faithful and Interpretable Large Language Model Reasoning 논문 리뷰 (ICLR 2024)
    AI/NLP 2024. 4. 22. 10:49
    728x90



     

     

    Reasoning On Graphs: Faithful and Interpretable Large Language Model Reasoning 논문 리뷰 
    (ICLR 2024)

     

     

     

    • 논문 읽은 후
      • RAG에서 Knowledge Graph를 활용한 논문으로, 이를 통해 LLM의 Reasoning 과정에 대한 해석력을 제공한다 
        • 방법이 크게 어렵진 않으나, appendix까지 쳬계적인 실험들과 case study들을 해서 잘 쓴 논문이라는 느낌이 확실히 들었다 역시 ICLR은 다르다.
      • RAG ... 이 분야 좀 재미있을지도? 연구할만한 주제가 많을 것 같다 
      • 또 읽어보고 싶은 논문 1, 논문 2

     

     

     

     


     

    Introduction

     

     

     

    •  배경
      • LLM의 Challenges : lack of knowledge & hallucination 
        • 이 원인은 reasoning process 에서 발생한 에러 
        • 이런 에러들이 legal judgement, medical diagnosis에 발생한다면 큰 문제
    •  Contribution
      • 이를 해결하기 위해 Knowledge graph, 그래프 내에 엔티티 간 semantic info.를 담아주는 데이터 표현 방식을 사용하겠다 
        • “faithful and interpretable reasoning” 지식그래프를 활용할 시 얻게되는 두 이점을 나타낸 키워드라고 할 수 있습니다. 정말 답변에 참고한 데이터가 있는지, 그리고 그 데이터가 어떻게 형성되어 있는지를 그래프 형태로 직관적으로 볼 수 있다 라는 점

     


     

    Approach : 

     

     

    출처 

    https://www.graphusergroup.com/24-2weeks-april-graph-omakase/

     

    24년 4월 둘째주 그래프 오마카세

    Graph Convolutional Networks using Heat Kernel for Semi-supervised Learning 배지훈 IJCAI 2020 link : https://arxiv.org/abs/2007.16002 code : https://github.com/Eilene/GraphHeat Keywords Semi-supervised Learning, Spectral graph convolution, Graph Heat Ke

    www.graphusergroup.com

     

     

    •  방법론
      • 잘 가져오고, 가져온것이 최적인가를 검증한 뒤 검증된 Path 를 LLM에게 전달하는 것 이 방법론 핵심입니다. 이를 위해 ELBO(evidence lower bound) 를 활용해 planning 로 생성된 relation path와 retrieval-reasoning 로 생성된 reasoning path 간의 정보량을 계산해 그 간극을 최소화하는 방식으로 최적화가 이루어집니다.
      • 이 때, 최적화의 대상은 LLM의 parameter입니다. 다시 말해서, 무엇을 가져오고 가져온 것을 기반으로 답을 만들 때 적합한지를 지속적으로 개선한다는 것이 본 아키텍쳐의 핵심이라 할 수 있습니다.

     

     

     

     

     

    • 방금 말한 방법론은 크게 3가지 모듈로 나누어집니다. 1. Planning , 2. Retrieval , 3. Reasoning 각각 무엇을 가져올지 계획하는 단계 , 실제 그 무엇을 가져오는 단계 그리고 가져온 것을 Reasoning 형태로 LLM에게 제공하는 단계입니다.
      • 1. Planning optimization
        • planning은 relation Paths를 가져오는 KG(knowledge graph) 로부터 추출할 때 무엇이 좋을지 계획하는 단계입니다. 유저의 질문을 기반으로 답변에 도움이 될만한 요소를 Knowledge graph 에서 추출합니다. 논문에서 “distill knowledge from KGs” 라고 표현할 만큼 최적의 결과 값을 뽑아내는게 핵심입니다.
        • 최적화 기준은 question 으로 부터 생성한 realtion path와 실제 path 가 answer 가 relation path와 연결되었는지를 Kullback–Leibler divergence(KLD , 쿨랙라이블러 발산) 을 활용하여 비교합니다. 간단히 말해서, 두 path 간 확률 분포 차를 계산하고 이를 최적화한다 라고 보시면 되겠습니다.
        • 이 때, 확률 분포를 계산할 때 우선 정규 분포를 가정하고 question 과 상응하는 answer이 서브그래프 내 존재할 시 이를 반영하는 방식으로 진행됩니다. 값은 shortest path 의 역수를 반영합니다.
      • 2. Retrieval-reasoning optimization
        • Retrieval-reasoning은 planning 으로부터 생성된 여러 path 들 중 무슨 path 가 과연 의미할지를 연산하는 단계입니다. FiD 프레임워크를 활용합니다. FiD 프레임워크란 Fusion-in-Decoder (FiD)의 약자입니다.
        • 다양한 passage 를 독립적으로 인코딩하고, 디코딩시에 독립적으로 인코딩 된 passage 값을 fusion 하기 위한 방법론으로써, planning 에서 생성된 여러개의 faithful relation 을 활용하기 위해 FiD 아이디어를 차용합니다.

     

     

     

    • 서두에 언급드렸다시피, 본 논문의 목적은 LLM의 파라미터를 학습하는게 목적입니다. 지금까지는 지식그래프에서 어떻게 가져오고 어떻게 주입하는지를 이야기했다면, 다음부터는 구체적으로 어떤 input형태로 LLM에게 주입되어 학습되는지에 대해 이야기합니다.
      • 1. planning module
        • Planning 은 relation path가 유의미한지를 LLM에게 토큰 형태로 주입해 학습하는 단계입니다. LLM 최적화를 위해 , relation 을 토큰 단위로 분절합니다. path / sep / path 3가지 토큰을 활용합니다.
        • <path> r1 <sep> r2 <sep> … <sep> ri </path> 형태로 주입되며 특정 path 마다 어떤 relation 이 담겨있는지를 토큰형태로 LLM에게 주입하고 이를 파라미터 최적화에 활용합니다.
      • 2. Retrieval - Reasoning module
        • 주어진 질문 그리고 relation path 를 활용해 사전에 형성되어 있는 Knowledge graph로 부터 가져오고, 이를 종합해 reasoning paths 그리고 질문 으로 가공하여 LLM에게 주입합니다. 이 때, reasoning paths 의 여러 path들 중 어느 path가 중요한지를 판별하기 위해 reasoning module을 활용합니다.
        • Reasoning module 은 path 결과물마다 answer 값이 정확한지 부정확한지를 확률 형태로 추출한 뒤 이를 기반으로 중요도를 판별하는 역할을 합니다.

     

     

    Experiments & Results 

     

     

     

    • 실험 세팅
      • Embedding , Retrieval , Semantic Parsing , LLMs , LLMs+KGs 총 5가지 방법론들을 활용해 실험
      • Backbone LLM으로는 LLama-7b-chat 모델을 활용
    • 결과 
      • RoG방법론이 LLaMA2 모델 뿐만아니라, 타 LLM 모델인 ChatGPT , Alpaca-7B , Flan-T5 에도 활용성이 높음

     

     

     

    • 1. Ablation Study

     

     

     

     

    • 2. Knowledge graph transferability
      • 다른 Knowledge Graph에도 우리 성능 괜찮아 ~

     

     

     

    • 2.  Retrieval time with Knowledge graph hops
      • Reasoning 을 통해 LLM의 답변 성능이 좋아진다 해도 사용자에게 전달되기까지 시간이 기존 대비 많이 소요된다면 실용성 떨어짐 -> Retrieval time 중요함 ㅇㅇ 

     

     

     

     

     

     

    Code

     

     

    https://github.com/RManLuo/reasoning-on-graphs

     

    GitHub - RManLuo/reasoning-on-graphs: Official Implementation of ICLR 2024 paper: "Reasoning on Graphs: Faithful and Interpretab

    Official Implementation of ICLR 2024 paper: "Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning" - RManLuo/reasoning-on-graphs

    github.com

     

     

    일단 나도 비슷한 프로젝트를 했기 때문에, retrieval한 그래프를 어떻게 prompt에 넘겨줬는지 궁금했다 

     

    import sys
    import os
    sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/..")
    import utils
    import random
    from typing import Callable
    
    class PromptBuilder(object):
        MCQ_INSTRUCTION = """Please answer the following questions. Please select the answers from the given choices and return the answer only."""
        SAQ_INSTRUCTION = """Please answer the following questions. Please keep the answer as simple as possible and return all the possible answer as a list."""
        MCQ_RULE_INSTRUCTION = """Based on the reasoning paths, please answer the given question. Please select the answers from the given choices and return the answers only."""
        SAQ_RULE_INSTRUCTION = """Based on the reasoning paths, please answer the given question. Please keep the answer as simple as possible and return all the possible answers as a list."""
        COT = """ Let's think it step by step."""
        EXPLAIN = """ Please explain your answer."""
        QUESTION = """Question:\n{question}"""
        GRAPH_CONTEXT = """Reasoning Paths:\n{context}\n\n"""
        CHOICES = """\nChoices:\n{choices}"""
        EACH_LINE = """ Please return each answer in a new line."""
        def __init__(self, prompt_path, add_rule = False, use_true = False, cot = False, explain = False, use_random = False, each_line = False, maximun_token = 4096, tokenize: Callable = lambda x: len(x)):
            self.prompt_template = self._read_prompt_template(prompt_path)
            self.add_rule = add_rule
            self.use_true = use_true
            self.use_random = use_random
            self.cot = cot
            self.explain = explain
            self.maximun_token = maximun_token
            self.tokenize = tokenize
            self.each_line = each_line
            
        def _read_prompt_template(self, template_file):
            with open(template_file) as fin:
                prompt_template = f"""{fin.read()}"""
            return prompt_template
        
        def apply_rules(self, graph, rules, srouce_entities):
            results = []
            for entity in srouce_entities:
                for rule in rules:
                    res = utils.bfs_with_rule(graph, entity, rule)
                    results.extend(res)
            return results
        
        def direct_answer(self, question_dict):
            graph = utils.build_graph(question_dict['graph'])
            entities = question_dict['q_entity']
            rules = question_dict['predicted_paths']
            prediction = []
            if len(rules) > 0:
                reasoning_paths = self.apply_rules(graph, rules, entities)
                for p in reasoning_paths:
                    if len(p) > 0:
                        prediction.append(p[-1][-1])
            return prediction
        
        
        def process_input(self, question_dict):
            '''
            Take question as input and return the input with prompt
            '''
            question = question_dict['question']
            
            if not question.endswith('?'):
                question += '?'
            
    
            if self.add_rule:
                graph = utils.build_graph(question_dict['graph'])
                entities = question_dict['q_entity']
                if self.use_true:
                    rules = question_dict['ground_paths']
                elif self.use_random:
                    _, rules = utils.get_random_paths(entities, graph)
                else:
                    rules = question_dict['predicted_paths']
                if len(rules) > 0:
                    reasoning_paths = self.apply_rules(graph, rules, entities)
                    lists_of_paths = [utils.path_to_string(p) for p in reasoning_paths]
                    # context = "\n".join([utils.path_to_string(p) for p in reasoning_paths])
                else:
                    lists_of_paths = []
                #input += self.GRAPH_CONTEXT.format(context = context)
                
            input = self.QUESTION.format(question = question)
            # MCQ
            if len(question_dict['choices']) > 0:
                choices = '\n'.join(question_dict['choices'])
                input += self.CHOICES.format(choices = choices)
                if self.add_rule:
                    instruction = self.MCQ_RULE_INSTRUCTION
                else:
                    instruction = self.MCQ_INSTRUCTION
            # SAQ
            else:
                if self.add_rule:
                    instruction = self.SAQ_RULE_INSTRUCTION
                else:
                    instruction = self.SAQ_INSTRUCTION
            
            if self.cot:
                instruction += self.COT
            
            if self.explain:
                instruction += self.EXPLAIN
                
            if self.each_line:
                instruction += self.EACH_LINE
            
            if self.add_rule:
                other_prompt = self.prompt_template.format(instruction = instruction, input = self.GRAPH_CONTEXT.format(context = "") + input)
                context = self.check_prompt_length(other_prompt, lists_of_paths, self.maximun_token)
                
                input = self.GRAPH_CONTEXT.format(context = context) + input
            
            input = self.prompt_template.format(instruction = instruction, input = input)
                
            return input
        
        def check_prompt_length(self, prompt, list_of_paths, maximun_token):
            '''Check whether the input prompt is too long. If it is too long, remove the first path and check again.'''
            all_paths = "\n".join(list_of_paths)
            all_tokens = prompt + all_paths
            if self.tokenize(all_tokens) < maximun_token:
                return all_paths
            else:
                # Shuffle the paths
                random.shuffle(list_of_paths)
                new_list_of_paths = []
                # check the length of the prompt
                for p in list_of_paths:
                    tmp_all_paths = "\n".join(new_list_of_paths + [p])
                    tmp_all_tokens = prompt + tmp_all_paths
                    if self.tokenize(tmp_all_tokens) > maximun_token:
                        return "\n".join(new_list_of_paths)
                    new_list_of_paths.append(p)

     

     

     

     

    • 실제는 이렇게 된답니다 

     

     

     

     

     

     

     

     


     

    Paper & Ref. 

     

    https://arxiv.org/abs/2310.01061

     

    Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning

    Large language models (LLMs) have demonstrated impressive reasoning abilities in complex tasks. However, they lack up-to-date knowledge and experience hallucinations during reasoning, which can lead to incorrect reasoning processes and diminish their perfo

    arxiv.org

     

    https://www.graphusergroup.com/24-2weeks-april-graph-omakase/

     

    24년 4월 둘째주 그래프 오마카세

    Graph Convolutional Networks using Heat Kernel for Semi-supervised Learning 배지훈 IJCAI 2020 link : https://arxiv.org/abs/2007.16002 code : https://github.com/Eilene/GraphHeat Keywords Semi-supervised Learning, Spectral graph convolution, Graph Heat Ke

    www.graphusergroup.com

     

    728x90
Designed by Tistory.