🚴‍♂️
TIL
  • MAIN
  • : TIL?
  • : WIL
  • : Plan
  • : Retrospective
    • 21Y
      • Wait a moment!
      • 9M 2W
      • 9M1W
      • 8M4W
      • 8M3W
      • 8M2W
      • 8M1W
      • 7M4W
      • 7M3W
      • 7M2W
      • 7M1W
      • 6M5W
      • 1H
    • 새사람 되기 프로젝트
      • 2회차
      • 1회차
  • TIL : ML
    • Paper Analysis
      • BERT
      • Transformer
    • Boostcamp 2st
      • [S]Data Viz
        • (4-3) Seaborn 심화
        • (4-2) Seaborn 기초
        • (4-1) Seaborn 소개
        • (3-4) More Tips
        • (3-3) Facet 사용하기
        • (3-2) Color 사용하기
        • (3-1) Text 사용하기
        • (2-3) Scatter Plot 사용하기
        • (2-2) Line Plot 사용하기
        • (2-1) Bar Plot 사용하기
        • (1-3) Python과 Matplotlib
        • (1-2) 시각화의 요소
        • (1-1) Welcome to Visualization (OT)
      • [P]MRC
        • (2강) Extraction-based MRC
        • (1강) MRC Intro & Python Basics
      • [P]KLUE
        • (5강) BERT 기반 단일 문장 분류 모델 학습
        • (4강) 한국어 BERT 언어 모델 학습
        • [NLP] 문장 내 개체간 관계 추출
        • (3강) BERT 언어모델 소개
        • (2강) 자연어의 전처리
        • (1강) 인공지능과 자연어 처리
      • [U]Stage-CV
      • [U]Stage-NLP
        • 7W Retrospective
        • (10강) Advanced Self-supervised Pre-training Models
        • (09강) Self-supervised Pre-training Models
        • (08강) Transformer (2)
        • (07강) Transformer (1)
        • 6W Retrospective
        • (06강) Beam Search and BLEU score
        • (05강) Sequence to Sequence with Attention
        • (04강) LSTM and GRU
        • (03강) Recurrent Neural Network and Language Modeling
        • (02강) Word Embedding
        • (01강) Intro to NLP, Bag-of-Words
        • [필수 과제 4] Preprocessing for NMT Model
        • [필수 과제 3] Subword-level Language Model
        • [필수 과제2] RNN-based Language Model
        • [선택 과제] BERT Fine-tuning with transformers
        • [필수 과제] Data Preprocessing
      • Mask Wear Image Classification
        • 5W Retrospective
        • Report_Level1_6
        • Performance | Review
        • DAY 11 : HardVoting | MultiLabelClassification
        • DAY 10 : Cutmix
        • DAY 9 : Loss Function
        • DAY 8 : Baseline
        • DAY 7 : Class Imbalance | Stratification
        • DAY 6 : Error Fix
        • DAY 5 : Facenet | Save
        • DAY 4 : VIT | F1_Loss | LrScheduler
        • DAY 3 : DataSet/Lodaer | EfficientNet
        • DAY 2 : Labeling
        • DAY 1 : EDA
        • 2_EDA Analysis
      • [P]Stage-1
        • 4W Retrospective
        • (10강) Experiment Toolkits & Tips
        • (9강) Ensemble
        • (8강) Training & Inference 2
        • (7강) Training & Inference 1
        • (6강) Model 2
        • (5강) Model 1
        • (4강) Data Generation
        • (3강) Dataset
        • (2강) Image Classification & EDA
        • (1강) Competition with AI Stages!
      • [U]Stage-3
        • 3W Retrospective
        • PyTorch
          • (10강) PyTorch Troubleshooting
          • (09강) Hyperparameter Tuning
          • (08강) Multi-GPU 학습
          • (07강) Monitoring tools for PyTorch
          • (06강) 모델 불러오기
          • (05강) Dataset & Dataloader
          • (04강) AutoGrad & Optimizer
          • (03강) PyTorch 프로젝트 구조 이해하기
          • (02강) PyTorch Basics
          • (01강) Introduction to PyTorch
      • [U]Stage-2
        • 2W Retrospective
        • DL Basic
          • (10강) Generative Models 2
          • (09강) Generative Models 1
          • (08강) Sequential Models - Transformer
          • (07강) Sequential Models - RNN
          • (06강) Computer Vision Applications
          • (05강) Modern CNN - 1x1 convolution의 중요성
          • (04강) Convolution은 무엇인가?
          • (03강) Optimization
          • (02강) 뉴럴 네트워크 - MLP (Multi-Layer Perceptron)
          • (01강) 딥러닝 기본 용어 설명 - Historical Review
        • Assignment
          • [필수 과제] Multi-headed Attention Assignment
          • [필수 과제] LSTM Assignment
          • [필수 과제] CNN Assignment
          • [필수 과제] Optimization Assignment
          • [필수 과제] MLP Assignment
      • [U]Stage-1
        • 1W Retrospective
        • AI Math
          • (AI Math 10강) RNN 첫걸음
          • (AI Math 9강) CNN 첫걸음
          • (AI Math 8강) 베이즈 통계학 맛보기
          • (AI Math 7강) 통계학 맛보기
          • (AI Math 6강) 확률론 맛보기
          • (AI Math 5강) 딥러닝 학습방법 이해하기
          • (AI Math 4강) 경사하강법 - 매운맛
          • (AI Math 3강) 경사하강법 - 순한맛
          • (AI Math 2강) 행렬이 뭐예요?
          • (AI Math 1강) 벡터가 뭐예요?
        • Python
          • (Python 7-2강) pandas II
          • (Python 7-1강) pandas I
          • (Python 6강) numpy
          • (Python 5-2강) Python data handling
          • (Python 5-1강) File / Exception / Log Handling
          • (Python 4-2강) Module and Project
          • (Python 4-1강) Python Object Oriented Programming
          • (Python 3-2강) Pythonic code
          • (Python 3-1강) Python Data Structure
          • (Python 2-4강) String and advanced function concept
          • (Python 2-3강) Conditionals and Loops
          • (Python 2-2강) Function and Console I/O
          • (Python 2-1강) Variables
          • (Python 1-3강) 파이썬 코딩 환경
          • (Python 1-2강) 파이썬 개요
          • (Python 1-1강) Basic computer class for newbies
        • Assignment
          • [선택 과제 3] Maximum Likelihood Estimate
          • [선택 과제 2] Backpropagation
          • [선택 과제 1] Gradient Descent
          • [필수 과제 5] Morsecode
          • [필수 과제 4] Baseball
          • [필수 과제 3] Text Processing 2
          • [필수 과제 2] Text Processing 1
          • [필수 과제 1] Basic Math
    • 딥러닝 CNN 완벽 가이드 - Fundamental 편
      • 종합 실습 2 - 캐글 Plant Pathology(나무잎 병 진단) 경연 대회
      • 종합 실습 1 - 120종의 Dog Breed Identification 모델 최적화
      • 사전 훈련 모델의 미세 조정 학습과 다양한 Learning Rate Scheduler의 적용
      • Advanced CNN 모델 파헤치기 - ResNet 상세와 EfficientNet 개요
      • Advanced CNN 모델 파헤치기 - AlexNet, VGGNet, GoogLeNet
      • Albumentation을 이용한 Augmentation기법과 Keras Sequence 활용하기
      • 사전 훈련 CNN 모델의 활용과 Keras Generator 메커니즘 이해
      • 데이터 증강의 이해 - Keras ImageDataGenerator 활용
      • CNN 모델 구현 및 성능 향상 기본 기법 적용하기
    • AI School 1st
    • 현업 실무자에게 배우는 Kaggle 머신러닝 입문
    • 파이썬 딥러닝 파이토치
  • TIL : Python & Math
    • Do It! 장고+부트스트랩: 파이썬 웹개발의 정석
      • Relations - 다대다 관계
      • Relations - 다대일 관계
      • 템플릿 파일 모듈화 하기
      • TDD (Test Driven Development)
      • template tags & 조건문
      • 정적 파일(static files) & 미디어 파일(media files)
      • FBV (Function Based View)와 CBV (Class Based View)
      • Django 입문하기
      • 부트스트랩
      • 프론트엔드 기초다지기 (HTML, CSS, JS)
      • 들어가기 + 환경설정
    • Algorithm
      • Programmers
        • Level1
          • 소수 만들기
          • 숫자 문자열과 영단어
          • 자연수 뒤집어 배열로 만들기
          • 정수 내림차순으로 배치하기
          • 정수 제곱근 판별
          • 제일 작은 수 제거하기
          • 직사각형 별찍기
          • 짝수와 홀수
          • 체육복
          • 최대공약수와 최소공배수
          • 콜라츠 추측
          • 크레인 인형뽑기 게임
          • 키패드 누르기
          • 평균 구하기
          • 폰켓몬
          • 하샤드 수
          • 핸드폰 번호 가리기
          • 행렬의 덧셈
        • Level2
          • 숫자의 표현
          • 순위 검색
          • 수식 최대화
          • 소수 찾기
          • 소수 만들기
          • 삼각 달팽이
          • 문자열 압축
          • 메뉴 리뉴얼
          • 더 맵게
          • 땅따먹기
          • 멀쩡한 사각형
          • 괄호 회전하기
          • 괄호 변환
          • 구명보트
          • 기능 개발
          • 뉴스 클러스터링
          • 다리를 지나는 트럭
          • 다음 큰 숫자
          • 게임 맵 최단거리
          • 거리두기 확인하기
          • 가장 큰 정사각형 찾기
          • H-Index
          • JadenCase 문자열 만들기
          • N개의 최소공배수
          • N진수 게임
          • 가장 큰 수
          • 124 나라의 숫자
          • 2개 이하로 다른 비트
          • [3차] 파일명 정렬
          • [3차] 압축
          • 줄 서는 방법
          • [3차] 방금 그곡
          • 거리두기 확인하기
        • Level3
          • 매칭 점수
          • 외벽 점검
          • 기지국 설치
          • 숫자 게임
          • 110 옮기기
          • 광고 제거
          • 길 찾기 게임
          • 셔틀버스
          • 단속카메라
          • 표 편집
          • N-Queen
          • 징검다리 건너기
          • 최고의 집합
          • 합승 택시 요금
          • 거스름돈
          • 하노이의 탑
          • 멀리 뛰기
          • 모두 0으로 만들기
        • Level4
    • Head First Python
    • 데이터 분석을 위한 SQL
    • 단 두 장의 문서로 데이터 분석과 시각화 뽀개기
    • Linear Algebra(Khan Academy)
    • 인공지능을 위한 선형대수
    • Statistics110
  • TIL : etc
    • [따배런] Kubernetes
    • [따배런] Docker
      • 2. 도커 설치 실습 1 - 학습편(준비물/실습 유형 소개)
      • 1. 컨테이너와 도커의 이해 - 컨테이너를 쓰는이유 / 일반프로그램과 컨테이너프로그램의 차이점
      • 0. 드디어 찾아온 Docker 강의! 왕초보에서 도커 마스터로 - OT
    • CoinTrading
      • [가상 화폐 자동 매매 프로그램] 백테스팅 : 간단한 테스팅
    • Gatsby
      • 01 깃북 포기 선언
  • TIL : Project
    • Mask Wear Image Classification
    • Project. GARIGO
  • 2021 TIL
    • CHANGED
    • JUN
      • 30 Wed
      • 29 Tue
      • 28 Mon
      • 27 Sun
      • 26 Sat
      • 25 Fri
      • 24 Thu
      • 23 Wed
      • 22 Tue
      • 21 Mon
      • 20 Sun
      • 19 Sat
      • 18 Fri
      • 17 Thu
      • 16 Wed
      • 15 Tue
      • 14 Mon
      • 13 Sun
      • 12 Sat
      • 11 Fri
      • 10 Thu
      • 9 Wed
      • 8 Tue
      • 7 Mon
      • 6 Sun
      • 5 Sat
      • 4 Fri
      • 3 Thu
      • 2 Wed
      • 1 Tue
    • MAY
      • 31 Mon
      • 30 Sun
      • 29 Sat
      • 28 Fri
      • 27 Thu
      • 26 Wed
      • 25 Tue
      • 24 Mon
      • 23 Sun
      • 22 Sat
      • 21 Fri
      • 20 Thu
      • 19 Wed
      • 18 Tue
      • 17 Mon
      • 16 Sun
      • 15 Sat
      • 14 Fri
      • 13 Thu
      • 12 Wed
      • 11 Tue
      • 10 Mon
      • 9 Sun
      • 8 Sat
      • 7 Fri
      • 6 Thu
      • 5 Wed
      • 4 Tue
      • 3 Mon
      • 2 Sun
      • 1 Sat
    • APR
      • 30 Fri
      • 29 Thu
      • 28 Wed
      • 27 Tue
      • 26 Mon
      • 25 Sun
      • 24 Sat
      • 23 Fri
      • 22 Thu
      • 21 Wed
      • 20 Tue
      • 19 Mon
      • 18 Sun
      • 17 Sat
      • 16 Fri
      • 15 Thu
      • 14 Wed
      • 13 Tue
      • 12 Mon
      • 11 Sun
      • 10 Sat
      • 9 Fri
      • 8 Thu
      • 7 Wed
      • 6 Tue
      • 5 Mon
      • 4 Sun
      • 3 Sat
      • 2 Fri
      • 1 Thu
    • MAR
      • 31 Wed
      • 30 Tue
      • 29 Mon
      • 28 Sun
      • 27 Sat
      • 26 Fri
      • 25 Thu
      • 24 Wed
      • 23 Tue
      • 22 Mon
      • 21 Sun
      • 20 Sat
      • 19 Fri
      • 18 Thu
      • 17 Wed
      • 16 Tue
      • 15 Mon
      • 14 Sun
      • 13 Sat
      • 12 Fri
      • 11 Thu
      • 10 Wed
      • 9 Tue
      • 8 Mon
      • 7 Sun
      • 6 Sat
      • 5 Fri
      • 4 Thu
      • 3 Wed
      • 2 Tue
      • 1 Mon
    • FEB
      • 28 Sun
      • 27 Sat
      • 26 Fri
      • 25 Thu
      • 24 Wed
      • 23 Tue
      • 22 Mon
      • 21 Sun
      • 20 Sat
      • 19 Fri
      • 18 Thu
      • 17 Wed
      • 16 Tue
      • 15 Mon
      • 14 Sun
      • 13 Sat
      • 12 Fri
      • 11 Thu
      • 10 Wed
      • 9 Tue
      • 8 Mon
      • 7 Sun
      • 6 Sat
      • 5 Fri
      • 4 Thu
      • 3 Wed
      • 2 Tue
      • 1 Mon
    • JAN
      • 31 Sun
      • 30 Sat
      • 29 Fri
      • 28 Thu
      • 27 Wed
      • 26 Tue
      • 25 Mon
      • 24 Sun
      • 23 Sat
      • 22 Fri
      • 21 Thu
      • 20 Wed
      • 19 Tue
      • 18 Mon
      • 17 Sun
      • 16 Sat
      • 15 Fri
      • 14 Thu
      • 13 Wed
      • 12 Tue
      • 11 Mon
      • 10 Sun
      • 9 Sat
      • 8 Fri
      • 7 Thu
      • 6 Wed
      • 5 Tue
      • 4 Mon
      • 3 Sun
      • 2 Sat
      • 1 Fri
  • 2020 TIL
    • DEC
      • 31 Thu
      • 30 Wed
      • 29 Tue
      • 28 Mon
      • 27 Sun
      • 26 Sat
      • 25 Fri
      • 24 Thu
      • 23 Wed
      • 22 Tue
      • 21 Mon
      • 20 Sun
      • 19 Sat
      • 18 Fri
      • 17 Thu
      • 16 Wed
      • 15 Tue
      • 14 Mon
      • 13 Sun
      • 12 Sat
      • 11 Fri
      • 10 Thu
      • 9 Wed
      • 8 Tue
      • 7 Mon
      • 6 Sun
      • 5 Sat
      • 4 Fri
      • 3 Tue
      • 2 Wed
      • 1 Tue
    • NOV
      • 30 Mon
Powered by GitBook
On this page
  • Multi-Headed Attention
  • Scaled Dot-Product Attention (SDPA)
  • Multi-Headed Attention (MHA)

Was this helpful?

  1. TIL : ML
  2. Boostcamp 2st
  3. [U]Stage-2
  4. Assignment

[필수 과제] Multi-headed Attention Assignment

PreviousAssignmentNext[필수 과제] LSTM Assignment

Last updated 3 years ago

Was this helpful?

Multi-Headed Attention

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
%matplotlib inline
%config InlineBackend.figure_format='retina'
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))
PyTorch version:[1.9.0+cu102].
device:[cuda:0].

Scaled Dot-Product Attention (SDPA)

Scaled Dot Attention은 self(single)-attention이고, 이후에 multi attention이 나올 것임

class ScaledDotProductAttention(nn.Module):
    def forward(self,Q,K,V,mask=None):
        d_K = K.size()[-1] # key dimension
        scores = Q.matmul(K.transpose(-2,-1)) / np.sqrt(d_K)
        if mask is not None:
            scores = scores.masked_fill(mask==0, -1e9)
        attention = F.softmax(scores,dim=-1)
        out = attention.matmul(V)
        return out,attention
  • 2 : Q, K, V 벡터를 입력받는다. 정확히는 이것이 Batch가 되어서 들어오게 된다.

  • 3 : Key dimension을 찾음. 왜? 점수를 square값으로 나눠야 하니까

  • 4 : 점수를 계산

  • 6 : softmax로 attention값 구하기

  • 7 : Value벡터와 attention곱 구하기

# Demo run of scaled dot product attention 
SPDA = ScaledDotProductAttention()
n_batch,d_K,d_V = 3,128,256 # d_K(=d_Q) does not necessarily be equal to d_V
n_Q,n_K,n_V = 30,50,50
Q = torch.rand(n_batch,n_Q,d_K)
K = torch.rand(n_batch,n_K,d_K)
V = torch.rand(n_batch,n_V,d_V)
out,attention = SPDA.forward(Q,K,V,mask=None)
def sh(x): return str(x.shape)[11:-1] 
print ("SDPA: Q%s K%s V%s => out%s attention%s"%
       (sh(Q),sh(K),sh(V),sh(out),sh(attention)))
SDPA: Q[3, 30, 128] K[3, 50, 128] V[3, 50, 256] => out[3, 30, 256] attention[3, 30, 50]
  • 3

    • n_batch : 단어의 개수

    • d_K : K벡터의 차원

    • d_V : V벡터의 차원

    • V벡터는 K벡터와 차원이 달라도 된다.

  • 4: 각각의 벡터의 총 개수

    • 쿼리벡터의 개수와 키벡터의 개수가 달라도 된다. 개수가 달라도 서로의 interaction을 계산할 수 있다.

# It supports 'multi-headed' attention
n_batch,n_head,d_K,d_V = 3,5,128,256
n_Q,n_K,n_V = 30,50,50 # n_K and n_V should be the same
Q = torch.rand(n_batch,n_head,n_Q,d_K)
K = torch.rand(n_batch,n_head,n_K,d_K)
V = torch.rand(n_batch,n_head,n_V,d_V)
out,attention = SPDA.forward(Q,K,V,mask=None)
# out: [n_batch x n_head x n_Q x d_V]
# attention: [n_batch x n_head x n_Q x n_K] 
def sh(x): return str(x.shape)[11:-1] 
print ("(Multi-Headed) SDPA: Q%s K%s V%s => out%s attention%s"%
       (sh(Q),sh(K),sh(V),sh(out),sh(attention)))
(Multi-Headed) SDPA: Q[3, 5, 30, 128] K[3, 5, 50, 128] V[3, 5, 50, 256] => out[3, 5, 30, 256] attention[3, 5, 30, 50]
  • 멀티헤드에서는 각각의 벡터를 몇개할지가 2번째 인자자리에 추가되었다.

Multi-Headed Attention (MHA)

class MultiHeadedAttention(nn.Module):
    def __init__(self,d_feat=128,n_head=5,actv=F.relu,USE_BIAS=True,dropout_p=0.1,device=None):
        """
        :param d_feat: feature dimension
        :param n_head: number of heads
        :param actv: activation after each linear layer
        :param USE_BIAS: whether to use bias
        :param dropout_p: dropout rate
        :device: which device to use (e.g., cuda:0)
        """
        super(MultiHeadedAttention,self).__init__()
        if (d_feat%n_head) != 0:
            raise ValueError("d_feat(%d) should be divisible by b_head(%d)"%(d_feat,n_head)) 
        self.d_feat = d_feat
        self.n_head = n_head
        self.d_head = self.d_feat // self.n_head
        self.actv = actv
        self.USE_BIAS = USE_BIAS
        self.dropout_p = dropout_p # prob. of zeroed

        self.lin_Q = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)
        self.lin_K = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)
        self.lin_V = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)
        self.lin_O = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)

        self.dropout = nn.Dropout(p=self.dropout_p)
  • 2 : dropout이 multihead 인자로 들어가게된다. 논문에는 설명이 안되어있는데 모든 코드에 쓴다

  • 21-24 : Q, K ,V 벡터를 구하는 신경망을 구성하고 나오는 결과값을 가공해주는 Output도 정의해준다.

    def forward(self,Q,K,V,mask=None):
        """
        :param Q: [n_batch, n_Q, d_feat]
        :param K: [n_batch, n_K, d_feat]
        :param V: [n_batch, n_V, d_feat] <= n_K and n_V must be the same 
        :param mask: 
        """
        n_batch = Q.shape[0]
        Q_feat = self.lin_Q(Q) 
        K_feat = self.lin_K(K) 
        V_feat = self.lin_V(V)
        # Q_feat: [n_batch, n_Q, d_feat]
        # K_feat: [n_batch, n_K, d_feat]
        # V_feat: [n_batch, n_V, d_feat]

        # Multi-head split of Q, K, and V (d_feat = n_head*d_head)
        Q_split = Q_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)
        K_split = K_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)
        V_split = V_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)
        # Q_split: [n_batch, n_head, n_Q, d_head]
        # K_split: [n_batch, n_head, n_K, d_head]
        # V_split: [n_batch, n_head, n_V, d_head]

        # Multi-Headed Attention
        d_K = K.size()[-1] # key dimension
        scores = torch.matmul(Q_split, K_split.permute(0, 1, 3, 2)) / np.sqrt(d_K)
        if mask is not None:
            scores = scores.masked_fill(mask==0,-1e9)
        attention = torch.softmax(scores,dim=-1)
        x_raw = torch.matmul(self.dropout(attention),V_split) # dropout is NOT mentioned in the paper
        # attention: [n_batch, n_head, n_Q, n_K]
        # x_raw: [n_batch, n_head, n_Q, d_head]

        # Reshape x
        x_rsh1 = x_raw.permute(0,2,1,3).contiguous()
        # x_rsh1: [n_batch, n_Q, n_head, d_head]
        x_rsh2 = x_rsh1.view(n_batch,-1,self.d_feat)
        # x_rsh2: [n_batch, n_Q, d_feat]

        # Linear
        x = self.lin_O(x_rsh2)
        # x: [n_batch, n_Q, d_feat]
        out = {'Q_feat':Q_feat,'K_feat':K_feat,'V_feat':V_feat,
               'Q_split':Q_split,'K_split':K_split,'V_split':V_split,
               'scores':scores,'attention':attention,
               'x_raw':x_raw,'x_rsh1':x_rsh1,'x_rsh2':x_rsh2,'x':x}
        return out
  • 9-11 : 각각의 벡터를 신경망에 넣는다

  • 17-19 : 그리고 이 벡터를 조각조각 내준다.

# Self-Attention Layer
n_batch = 128
n_src   = 32
d_feat  = 200
n_head  = 5
src = torch.rand(n_batch,n_src,d_feat)
self_attention = MultiHeadedAttention(
    d_feat=d_feat,n_head=n_head,actv=F.relu,USE_BIAS=True,dropout_p=0.1,device=device)
out = self_attention.forward(src,src,src,mask=None)

Q_feat,K_feat,V_feat = out['Q_feat'],out['K_feat'],out['V_feat']
Q_split,K_split,V_split = out['Q_split'],out['K_split'],out['V_split']
scores,attention = out['scores'],out['attention']
x_raw,x_rsh1,x_rsh2,x = out['x_raw'],out['x_rsh1'],out['x_rsh2'],out['x']

# Print out shapes
def sh(_x): return str(_x.shape)[11:-1] 
print ("Input src:\t%s  \t= [n_batch, n_src, d_feat]"%(sh(src)))
print ()
print ("Q_feat:   \t%s  \t= [n_batch, n_src, d_feat]"%(sh(Q_feat)))
print ("K_feat:   \t%s  \t= [n_batch, n_src, d_feat]"%(sh(K_feat)))
print ("V_feat:   \t%s  \t= [n_batch, n_src, d_feat]"%(sh(V_feat)))
print ()
print ("Q_split:  \t%s  \t= [n_batch, n_head, n_src, d_head]"%(sh(Q_split)))
print ("K_split:  \t%s  \t= [n_batch, n_head, n_src, d_head]"%(sh(K_split)))
print ("V_split:  \t%s  \t= [n_batch, n_head, n_src, d_head]"%(sh(V_split)))
print ()
print ("scores:   \t%s  \t= [n_batch, n_head, n_src, n_src]"%(sh(scores)))
print ("attention:\t%s  \t= [n_batch, n_head, n_src, n_src]"%(sh(attention)))
print ()
print ("x_raw:    \t%s  \t= [n_batch, n_head, n_src, d_head]"%(sh(x_raw)))
print ("x_rsh1:   \t%s  \t= [n_batch, n_src, n_head, d_head]"%(sh(x_rsh1)))
print ("x_rsh2:   \t%s  \t= [n_batch, n_src, d_feat]"%(sh(x_rsh2)))
print ()
print ("Output x: \t%s  \t= [n_batch, n_src, d_feat]"%(sh(x)))
Input src:	[128, 32, 200]  	= [n_batch, n_src, d_feat]

Q_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
K_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]
V_feat:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]

Q_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
K_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
V_split:  	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]

scores:   	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]
attention:	[128, 5, 32, 32]  	= [n_batch, n_head, n_src, n_src]

x_raw:    	[128, 5, 32, 40]  	= [n_batch, n_head, n_src, d_head]
x_rsh1:   	[128, 32, 5, 40]  	= [n_batch, n_src, n_head, d_head]
x_rsh2:   	[128, 32, 200]  	= [n_batch, n_src, d_feat]

Output x: 	[128, 32, 200]  	= [n_batch, n_src, d_feat]