Date:     Updated:

카테고리:

태그: , , , ,


💡 교내 학회 NLP 분반에서 학습한 내용을 정리한 포스팅입니다.


본 과제는 영화 대사를 학습데이터로 사용하여 transformer로 챗봇을 구현하는 과제이다.


1. 라이브러리 설치, 경로 설정

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data
import math
import torch.nn.functional as F
corpus_movie_conv = '/content/drive/MyDrive/kubig/chatbot/movie_conversations.txt'
corpus_movie_lines = '/content/drive/MyDrive/kubig/chatbot/movie_lines.txt'
max_len = 25 #대사당 최대 단어 개수
with open(corpus_movie_conv, 'r') as c:
    conv = c.readlines()

with open(corpus_movie_lines, 'r',encoding= 'unicode_escape') as l:
    lines = l.readlines()


2. txt 파일 불러오고 전처리해주기

movie_lines.txt 파일에는 영화 대사들이 주어져있다. 총 5개의 컬럼이 +++$+++라는 구분자로 나뉘어 있는데, 이 중에서 사용할 것은 대사(line)의 id인 첫번째 컬럼과 대사 text인 마지막 컬럼이다. 아래 코드를 보면 key를 대사 id로 하고 value를 대사 text로 하는 딕셔너리를 만들었다.

참고) 아래 사진은 movie_lines.txt 파일의 캡처본으로, 다음을 의미한다

  • 1열: 대사 id
  • 2열: 대사를 한 캐릭터 id
  • 3열: 영화 id
  • 4열: 캐릭터 이름
  • 5열: 대사 내용 text

movie_lines_capture.png

lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1] #object[0]이 대사 id, object[-1]이 대사 내용 text
print("총 대사(line)의 개수:",len(lines_dic))
총 대사(line)의 개수: 304713
## 구두점, 기호를 없애주고 소문자로 바꿔주는 함수
def remove_punc(string):
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punct = ""
    for char in string:
        if char not in punctuations:
            no_punct = no_punct + char  # space is also a character
    return no_punct.lower()


3. 대사 쌍(pairs) 만들어주기

챗봇은 질문과 답변으로 구성된다. 즉 학습데이터 역시 질문 문장과 답변 문장의 쌍(pair)으로 구성이 되어야 한다. 아까 불러온 대사(line)들을 pair로 짝지어주기 위한 정보를 movie_converstion.txt 파일에서 불러올 것이다.

또한 트랜스포머에 입력 데이터로 넣기 위해 단어(토큰) 단위로 split을 해줘야 한다.

movie_conversations.txt 파일은 아래 사진처럼 구성되어 있다.

  • 1열: 대화에 등장하는 첫번째 캐릭터의 id
  • 2열: 대화에 등장하는 두번째 캐릭터의 id
  • 3열: 영화 id
  • 4열: 대화가 구성되는 line들의 리스트

movie.conversations.txt_capture.png

위 사진처럼 L194, L195, L196, L197이 하나의 conversation으로 묶여있다고 가정하자. 그럼 아래 함수는 L194, L195를 하나의 pair로, L195, L196을 하나의 pair로, L196, L197을 하나의 pair로 만들어준다.

즉 pairs라는 리스트를 보면 아래처럼 구성될 것이다.

  • pairs[0]: [[L194 대사의 단어들 토큰 리스트], [L195 대사의 단어들 토큰 리스트]]
  • pairs[1]: [[L195 대사의 단어들 토큰 리스트], [L196 대사의 단어들 토큰 리스트]]
  • pairs[2]: [[L196 대사의 단어들 토큰 리스트], [L197 대사의 단어들 토큰 리스트]]
  • pairs[3]: [[L198 대사의 단어들 토큰 리스트], [L199 대사의 단어들 토큰 리스트]]
pairs = []
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1]) # 각 행의 line 리스트를 ids에 저장
    for i in range(len(ids)):
        qa_pairs = []
        
        if i==len(ids)-1:
            break
        
        first = remove_punc(lines_dic[ids[i]].strip())      
        second = remove_punc(lines_dic[ids[i+1]].strip())
        qa_pairs.append(first.split()[:max_len]) # 띄어쓰기 단위로 split해주고 max_len(25)개만큼의 단어만 담을 것임
        qa_pairs.append(second.split()[:max_len])
        pairs.append(qa_pairs)
print(lines_dic['L194']) # 질문
print(lines_dic['L195']) # 답변
Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.

Well, I thought we'd start with pronunciation, if that's okay with you.
print(pairs[0][0]) # 질문
print(pairs[0][1]) # 답변
['can', 'we', 'make', 'this', 'quick', 'roxanne', 'korrine', 'and', 'andrew', 'barrett', 'are', 'having', 'an', 'incredibly', 'horrendous', 'public', 'break', 'up', 'on', 'the', 'quad', 'again']
['well', 'i', 'thought', 'wed', 'start', 'with', 'pronunciation', 'if', 'thats', 'okay', 'with', 'you']
print("만들어진 문장 pairs의 개수:",len(pairs))
만들어진 문장 pairs의 개수: 221616

Counter.update()는 추가된 리스트를 누적하여 카운팅해준다.

word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])


4. 정수 인코딩하여 시퀀스 형태로 변형

counting 결과 빈도수가 5를 넘는 단어들만 word_map 딕셔너리에 정수 인코딩하여 담아준다. 이 때 네 가지 스페셜 토큰들도 정수 인코딩을 해준다.

min_word_freq = 5
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v + 1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1 #18240
word_map['<start>'] = len(word_map) + 1 #18241
word_map['<end>'] = len(word_map) + 1 #18242
word_map['<pad>'] = 0
print("Total words are: {}".format(len(word_map)))
Total words are: 18243
# 생성한 word_map을 dump해준다.
with open('WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

정수 인코딩 정보가 들어있는 word_map을 토대로 question 대사와 reply 문장을 정수 인코딩하여 시퀀스 형태로 만들 것이다.

question은 대사를 정수인코딩 하고 뒤에 남는 부분을 패딩하는 반면, reply는 시작과 끝에 <start>, <end> 토큰을 넣고 뒤에 남는 부분을 패딩한다.

def encode_question(words, word_map):
    enc_c = [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

def encode_reply(words, word_map):
    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + \
    [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(words))
    return enc_c

인코딩된 pairs가 paris_encoded 리스트에 담긴다.

pairs_encoded = []
for pair in pairs:
    qus = encode_question(pair[0], word_map)
    ans = encode_reply(pair[1], word_map)
    pairs_encoded.append([qus, ans])
# 생성한 pairs_encoded을 dump해준다.
with open('pairs_encoded.json', 'w') as p:
    json.dump(pairs_encoded, p)


5. dataset 만들기

class Dataset(Dataset):

    def __init__(self):

        self.pairs = json.load(open('pairs_encoded.json'))
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
            
        return question, reply

    def __len__(self):
        return self.dataset_size
train_loader = torch.utils.data.DataLoader(Dataset(),
                                           batch_size = 100, 
                                           shuffle=True, 
                                           pin_memory=True)

데이터셋 준비가 완료되었으니 본격적으로 transformer 구조를 구현해보도록 한다. 대략적인 구조는 아래 사진과 같다. 크게 인코더와 디코더로 나뉘고, 각각은 2개의 서브층과 3개의 서브층을 가지고 있다.

인코더, 디코더 모두 단어 임베딩을 시킨 후 positional encoding까지 시킨 것을 입력값으로 받고 있다.

transformer_attention_overview.png


6. Masking

아래 함수는 참고하면 안되는 요소들을 가려주는 mask를 구현한 것이다. 후술하겠지만, 트랜스포머 구조에는 3개의 멀티헤드 어텐션이 있다.

  • 인코더의 첫번째 서브층: padding mask
  • 디코더의 첫번째 서브층: look-ahead mask
  • 디코더의 두번째 서브층: padding mask

각 멀티헤드 어텐션은 위 같은 마스크를 함수 인자로 받는다.

def create_masks(question, reply_input, reply_target):
    
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = question!=0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)         # (batch_size, 1, 1, max_words)
     
    reply_input_mask = reply_input!=0
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) 
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target!=0              # (batch_size, max_words)
    
    return question_mask, reply_input_mask, reply_target_mask


7. 포지셔널 인코딩(Positional Encoding)

트랜스포머는 RNN과 달리 각 단어의 위치 정보를 알 수 없다는 단점을 지닌다. 때문에 Positional encoding을 통해 임베딩 벡터를 만들 때 위치 정보를 더하는 작업을 거쳐줘야 한다. 그리하여 같은 단어라고 하더라도 문장 내 위치에 따라 임베딩 벡터값은 달라지게 된다.

positional encoding.png

class Embeddings(nn.Module):
    """
    Implements embeddings of the words and adds their positional encodings. 
    """
    def __init__(self, vocab_size, d_model, max_len = 50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positinal_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)
        
    def create_positinal_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len):   # 각 단어의 위치치
            for i in range(0, d_model, 2):   # 한 단어 내에서의 dimension
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model))) # 짝수의 경우 sin함수 사용
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model))) # 홀수의 경우 cos함수 사용
        pe = pe.unsqueeze(0)   # 첫번째 차원에 1(batch size)을 추가
        return pe
        
    def forward(self, encoded_words):
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
        # 임베딩 벡터의 행렬과 포지셔널 인코딩 행렬을 덧셈 연산해준다.    
        embedding += self.pe[:, :embedding.size(1)]   # pe will automatically be expanded with the same batch size as encoded_words
        embedding = self.dropout(embedding)
        return embedding


8. 멀티헤드 어텐션(Multi-Head Attention)

MultiHeadAttention은 어텐션을 병렬로 수행하는 것이다. 한 번의 어텐션보다 여러번의 어텐션을 병렬로 수행하는 것이 더 효과적이기 때문에 8개의 병렬 어텐션을 주로 사용한다. 몇 개를 사용할지는 아래 코드에서 heads라는 파라미터로 결정한다.

encoder self attention.png

위 그림을 보면 self-attention임을 확인할 수 있다. 어텐션을 자기 자신에게 수행한다는 의미이다. 즉 Q(쿼리), K(키), V(밸류)가 모두 ‘입력 문장의 모든 단어 벡터’들이 된다.

scaled dot product.png

아래 코드의 scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1)) 부분부터는 위 scaled dot product를 구현하는 과정이다.

class MultiHeadAttention(nn.Module):
    
    def __init__(self, heads, d_model):
        
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads # 단어벡터의 차원을 heads로 나눠주어 병렬 작업이 가능하게 한다.
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, 512)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, 512)
        query = self.query(query)
        key = self.key(key)        
        value = self.value(value)   
        
        # (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        # permute 함수는 차원들을 맞교환한다. 행렬곱을 위하여 transpose하는 과정이라 생각하면 된다.
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)   
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        
        # scaled dot product 구현 과정!
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1))

        # <pad> 토큰이 있는 경우 아주 작은 음수(-1e9)를 곱하여 소프트맥스에 의해 0이 되게끔한다.
        # 즉 단어 간 유사도를 구할 때 <pad> 토큰은 반영이 되지 않게끔 한다.
        scores = scores.masked_fill(mask == 0, -1e9)    # (batch_size, h, max_len, max_len)
        weights = F.softmax(scores, dim = -1)           # (batch_size, h, max_len, max_len)
        weights = self.dropout(weights)
        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # 마지막으로 병렬로 만들어진 모든 어텐션 헤드를 concat해준다.
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0,2,1,3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        # (batch_size, max_len, h * d_k)
        interacted = self.concat(context)
        return interacted 

9. FeedForward

ffnn.png

멀티 헤드 어텐션을 통과한 결과물은 feedforward의 입력값으로 사용된다.

class FeedForward(nn.Module):

    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

10. 인코더(Encoder) 전체 구조

layernorm.png

두 개의 서브층인 멀티헤드 셀프어텐션과 FFNN을 지날 때 각각 Add&Norm을 거치고 있다. 이는 residual connection과 layer normalization을 수행해준다는 의미이다.

  • residual connection: 서브층의 input + 서브층의 output
  • layer normalization: 텐서의 마지막 차원에 대해 평균과 분산을 구하여 어떤 수식을 통해 값을 정규화
class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # interacted(멀티헤드의 output)와 embeddings(멀티헤드의 input)을 더해주고 이를 layernorm(층정규화)시킴
        interacted = self.layernorm(interacted + embeddings)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        # feed_forward_oout(feedforward의 output)와 interacted(멀티헤드의 output이자 feedforward의 input)을 더해주고 
        # 이를 layernorm(층정규화)시킴
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

11. 디코더(Decoder) 전체 구조

decoder.png

위 사진은 디코더를 시각화한 것이다. 첫번째 서브층은 멀티헤드 셀프어텐션인 반면, 두번째 서브층은 멀티헤드 어텐션이다.

첫번째 서브층에서는 look-ahead mask를 해야하는데, RNN이 매 시점마다 단어를 입력받는 반면, 어텐션은 한꺼번에 단어 행렬을 입력받으므로 미래 시점의 단어까지 참고할 수 있는 상황이 발생해버린다. 그래서 미래 시점의 단어를 가려주는 look-ahead mask를 실시해주는 것이다. 첫 부분에서 언급한 바 있다.

두번째 서브층이 멀티헤드 어텐션이긴 하나, 셀프 어텐션이 아니다. 즉, Query가 디코더의 행렬인 반면, Key, Value는 인코더의 행렬이다. 위 그림의 화살표 모양에서도 이 점을 확인할 수 있다.

class DecoderLayer(nn.Module):
    
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, embeddings, encoded, src_mask, target_mask):
        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask)) # look-ahead mask
        query = self.layernorm(query + embeddings)
        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded

12. 트랜스포머(Transformer) 구현

이제 앞서 만든 인코더와 디코더를 조립하여 트랜스포머를 구현할 차례이다. num_layers에 인코더, 디코더 층의 개수를 설정해줄 수 있다. 이 층 개수만큼 ModuleList를 만들어주어 반복문을 순회하며 인코더, 디코더를 통과하는 원리이다.

마지막 출력층은 vocab_size를 설정하여 전체 단어 중에 예측하는 다중 분류로 softmax시키게 된다.

class Transformer(nn.Module):
    
    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)]) #num_layers 개수만큼 층 생성
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)]) #num_layers 개수만큼 층 생성
        self.logit = nn.Linear(d_model, self.vocab_size)
        
    def encode(self, src_words, src_mask):
        src_embeddings = self.embed(src_words)
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings
    
    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        tgt_embeddings = self.embed(target_words)
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
        return tgt_embeddings
        
    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim = 2)
        return out

13. Learning Rate Sheduler와 손실함수 정의

class AdamWarmup:
    
    def __init__(self, model_size, warmup_steps, optimizer):
        
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0
        
    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
        
    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate
        self.lr = lr
        self.optimizer.step()       
class LossWithLS(nn.Module):

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size
        
    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.contiguous().view(-1)   # (batch_size * max_words)
        mask = mask.float()
        mask = mask.view(-1)       # (batch_size * max_words)
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss

14. Train and Evaluate

d_model = 512 # 단어 임베딩딩 벡터의 차원원
heads = 8 # 멀티헤드 어텐션 병렬 헤드 개수
num_layers = 3 # 인코더, 디코드의 층 개수수
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 5

with open('WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)
    
transformer = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, word_map = word_map)
transformer = transformer.to(device)
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)
criterion = LossWithLS(len(word_map), 0.1)
def train(train_loader, transformer, criterion, epoch):
    
    transformer.train()
    sum_loss = 0
    count = 0

    for i, (question, reply) in enumerate(train_loader):
        
        samples = question.shape[0]

        # Move to device
        question = question.to(device)
        reply = reply.to(device)

        # Prepare Target Data
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # 앞서 만든 마스크 함수를 사용한다.
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

        # 앞서 정의한 trasnformer 클래스스로 결과를 출력한다.
        out = transformer(question, question_mask, reply_input, reply_input_mask)

        # Compute the loss
        loss = criterion(out, reply_target, reply_target_mask)
        
        # 역전파를 수행한다.
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()
        
        sum_loss += loss.item() * samples
        count += samples
        
        if i % 100 == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(epoch, i, len(train_loader), sum_loss/count))
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device)
    
    for step in range(max_len - 1):
        size = words.shape[1]
        target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
        decoded = transformer.decode(words, target_mask, encoded, question_mask)
        predictions = transformer.logit(decoded[:, -1])
        _, next_word = torch.max(predictions, dim = 1)
        next_word = next_word.item()
        if next_word == word_map['<end>']:
            break
        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1)   # (1,step+2)
        
    # Construct Sentence
    if words.dim() == 2:
        words = words.squeeze(0)
        words = words.tolist()
        
    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
    
    return sentence
for epoch in range(epochs):
    
    train(train_loader, transformer, criterion, epoch)
    
    state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    torch.save(state, 'checkpoint_' + str(epoch) + '.pth.tar')
Epoch [0][0/2217]	Loss: 8.655
Epoch [0][100/2217]	Loss: 7.907
Epoch [0][200/2217]	Loss: 7.196
Epoch [0][300/2217]	Loss: 6.666
Epoch [0][400/2217]	Loss: 6.308
Epoch [0][500/2217]	Loss: 6.061
Epoch [0][600/2217]	Loss: 5.877
Epoch [0][700/2217]	Loss: 5.728
Epoch [0][800/2217]	Loss: 5.612
Epoch [0][900/2217]	Loss: 5.514
Epoch [0][1000/2217]	Loss: 5.429
Epoch [0][1100/2217]	Loss: 5.358
Epoch [0][1200/2217]	Loss: 5.296
Epoch [0][1300/2217]	Loss: 5.241
Epoch [0][1400/2217]	Loss: 5.193
Epoch [0][1500/2217]	Loss: 5.151
Epoch [0][1600/2217]	Loss: 5.113
Epoch [0][1700/2217]	Loss: 5.079
Epoch [0][1800/2217]	Loss: 5.048
Epoch [0][1900/2217]	Loss: 5.020
Epoch [0][2000/2217]	Loss: 4.994
Epoch [0][2100/2217]	Loss: 4.970
Epoch [0][2200/2217]	Loss: 4.947
Epoch [1][0/2217]	Loss: 4.437
Epoch [1][100/2217]	Loss: 4.425
Epoch [1][200/2217]	Loss: 4.424
Epoch [1][300/2217]	Loss: 4.420
Epoch [1][400/2217]	Loss: 4.418
Epoch [1][500/2217]	Loss: 4.418
Epoch [1][600/2217]	Loss: 4.419
Epoch [1][700/2217]	Loss: 4.415
Epoch [1][800/2217]	Loss: 4.414
Epoch [1][900/2217]	Loss: 4.412
Epoch [1][1000/2217]	Loss: 4.411
Epoch [1][1100/2217]	Loss: 4.411
Epoch [1][1200/2217]	Loss: 4.409
Epoch [1][1300/2217]	Loss: 4.409
Epoch [1][1400/2217]	Loss: 4.409
Epoch [1][1500/2217]	Loss: 4.409
Epoch [1][1600/2217]	Loss: 4.409
Epoch [1][1700/2217]	Loss: 4.409
Epoch [1][1800/2217]	Loss: 4.409
Epoch [1][1900/2217]	Loss: 4.409
Epoch [1][2000/2217]	Loss: 4.408
Epoch [1][2100/2217]	Loss: 4.407
Epoch [1][2200/2217]	Loss: 4.405
Epoch [2][0/2217]	Loss: 4.396
Epoch [2][100/2217]	Loss: 4.296
Epoch [2][200/2217]	Loss: 4.303
Epoch [2][300/2217]	Loss: 4.302
Epoch [2][400/2217]	Loss: 4.300
Epoch [2][500/2217]	Loss: 4.303
Epoch [2][600/2217]	Loss: 4.307
Epoch [2][700/2217]	Loss: 4.308
Epoch [2][800/2217]	Loss: 4.310
Epoch [2][900/2217]	Loss: 4.311
Epoch [2][1000/2217]	Loss: 4.311
Epoch [2][1100/2217]	Loss: 4.311
Epoch [2][1200/2217]	Loss: 4.311
Epoch [2][1300/2217]	Loss: 4.310
Epoch [2][1400/2217]	Loss: 4.310
Epoch [2][1500/2217]	Loss: 4.309
Epoch [2][1600/2217]	Loss: 4.309
Epoch [2][1700/2217]	Loss: 4.308
Epoch [2][1800/2217]	Loss: 4.306
Epoch [2][1900/2217]	Loss: 4.306
Epoch [2][2000/2217]	Loss: 4.305
Epoch [2][2100/2217]	Loss: 4.303
Epoch [2][2200/2217]	Loss: 4.301
Epoch [3][0/2217]	Loss: 4.034
Epoch [3][100/2217]	Loss: 4.184
Epoch [3][200/2217]	Loss: 4.190
Epoch [3][300/2217]	Loss: 4.193
Epoch [3][400/2217]	Loss: 4.192
Epoch [3][500/2217]	Loss: 4.193
Epoch [3][600/2217]	Loss: 4.192
Epoch [3][700/2217]	Loss: 4.194
Epoch [3][800/2217]	Loss: 4.194
Epoch [3][900/2217]	Loss: 4.195
Epoch [3][1000/2217]	Loss: 4.199
Epoch [3][1100/2217]	Loss: 4.201
Epoch [3][1200/2217]	Loss: 4.202
Epoch [3][1300/2217]	Loss: 4.202
Epoch [3][1400/2217]	Loss: 4.203
Epoch [3][1500/2217]	Loss: 4.203
Epoch [3][1600/2217]	Loss: 4.204
Epoch [3][1700/2217]	Loss: 4.204
Epoch [3][1800/2217]	Loss: 4.204
Epoch [3][1900/2217]	Loss: 4.204
Epoch [3][2000/2217]	Loss: 4.204
Epoch [3][2100/2217]	Loss: 4.204
Epoch [3][2200/2217]	Loss: 4.204
Epoch [4][0/2217]	Loss: 4.173
Epoch [4][100/2217]	Loss: 4.101
Epoch [4][200/2217]	Loss: 4.103
Epoch [4][300/2217]	Loss: 4.105
Epoch [4][400/2217]	Loss: 4.110
Epoch [4][500/2217]	Loss: 4.114
Epoch [4][600/2217]	Loss: 4.117
Epoch [4][700/2217]	Loss: 4.117
Epoch [4][800/2217]	Loss: 4.119
Epoch [4][900/2217]	Loss: 4.119
Epoch [4][1000/2217]	Loss: 4.122
Epoch [4][1100/2217]	Loss: 4.124
Epoch [4][1200/2217]	Loss: 4.125
Epoch [4][1300/2217]	Loss: 4.126
Epoch [4][1400/2217]	Loss: 4.127
Epoch [4][1500/2217]	Loss: 4.127
Epoch [4][1600/2217]	Loss: 4.128
Epoch [4][1700/2217]	Loss: 4.129
Epoch [4][1800/2217]	Loss: 4.129
Epoch [4][1900/2217]	Loss: 4.130
Epoch [4][2000/2217]	Loss: 4.131
Epoch [4][2100/2217]	Loss: 4.132
Epoch [4][2200/2217]	Loss: 4.132
checkpoint = torch.load('/content/checkpoint_1.pth.tar')
transformer = checkpoint['transformer']
while(1):
    question = input("Question: ") 
    if question == 'quit':
        break
    max_len = input("Maximum Reply Length: ")
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)  
    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
    print(sentence)
i dont know
i dont know
i dont know