23. LSTM and GRU Cells

  • In this lesson we look at LSTM and GRU, more advanced recurrent layers. They involve considerably more computation than the SimpleRNN covered last time, but their stronger performance makes them the common choice in recurrent neural networks.
  • A basic recurrent layer generally has trouble learning long sequences: the longer the sequence, the more the information carried in the recurring hidden state gets diluted, so it becomes hard to pick up on words that appeared far back. LSTM and GRU cells were invented to solve this.

1. LSTM Structure

  • LSTM stands for Long Short-Term Memory. It was devised to keep short-term memories around for a long time. Because the LSTM structure is complex, we will go through it step by step; the basic idea, however, is the same as before.
  • Multiply the weights by the input, add a bias, and pass the result through an activation function. An LSTM simply contains several of these units, and their results are reused at the next timestep.
  • First, how the hidden state is made: the input and the previous timestep's hidden state are multiplied by the weights and passed through an activation function to produce the next hidden state. Unlike the basic recurrent layer, a sigmoid function is used here, and the result is then multiplied by some value that has passed through a hyperbolic tangent function to form the hidden state.
  • An LSTM has two recurring states. Besides the hidden state there is another value called the cell state. Unlike the hidden state, the cell state is not passed on to the next layer; it only circulates inside the LSTM cell.
  • In total an LSTM contains four of these small cells. There is no need to do the cell computations by hand, since Keras provides an LSTM class, but a minimal sketch of one LSTM step is shown below.
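  • To make the "four small cells" concrete, here is a rough NumPy sketch of a single LSTM timestep. The names (x, h_prev, c_prev, and the per-gate W/U/b) are made up for illustration and the weights are random; it only shows the shape of the standard LSTM equations, not Keras's exact implementation.

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

input_dim, units = 16, 8                  # same sizes as the Keras model below
rng = np.random.default_rng(42)

# one small cell = input weights, recurrent weights and a bias;
# the LSTM has four of them: input gate i, forget gate f, candidate c, output gate o
W = {g: rng.normal(size=(input_dim, units)) * 0.1 for g in "ifco"}
U = {g: rng.normal(size=(units, units)) * 0.1 for g in "ifco"}
b = {g: np.zeros(units) for g in "ifco"}

def lstm_step(x, h_prev, c_prev):
    i = sigmoid(x @ W["i"] + h_prev @ U["i"] + b["i"])    # input gate
    f = sigmoid(x @ W["f"] + h_prev @ U["f"] + b["f"])    # forget gate
    o = sigmoid(x @ W["o"] + h_prev @ U["o"] + b["o"])    # output gate
    g = np.tanh(x @ W["c"] + h_prev @ U["c"] + b["c"])    # candidate cell state
    c = f * c_prev + i * g                # cell state: circulates only inside the cell
    h = o * np.tanh(c)                    # hidden state: passed on to the next layer
    return h, c

h, c = lstm_step(rng.normal(size=input_dim), np.zeros(units), np.zeros(units))
print(h.shape, c.shape)                   # (8,) (8,)

# the four weight sets also explain the parameter count shown later by model.summary():
print(4 * (input_dim * units + units * units + units))    # 800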

2. Training an LSTM Network

from tensorflow import keras
from tensorflow.keras.datasets import imdb
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
(train_input, train_target), (test_input, test_target) = imdb.load_data(num_words=500)
train_input, val_input, train_target, val_target = train_test_split(train_input, train_target,
                                                                  test_size=0.2, random_state=42)
  • We load the data we need, keeping the training and test sets, and then split the training set once more into a training set and a validation set.
  • Next, Keras's pad_sequences() function sets the length of every sample to 100, adding padding whenever a sample is shorter than that.
train_seq = pad_sequences(train_input, maxlen=100)
val_seq = pad_sequences(val_input, maxlen=100)
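  • To double-check the preprocessing, the shapes of the padded arrays can be printed; with an 80/20 split of the 25,000 IMDB training reviews, the shapes below are what I would expect.

print(train_seq.shape, val_seq.shape)     # expected: (20000, 100) (5000, 100)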
  • Now we build a recurrent layer that uses LSTM cells. Starting from the earlier SimpleRNN model, all we have to do is replace SimpleRNN with LSTM.
model = keras.Sequential()
model.add(keras.layers.Embedding(500, 16, input_length=100))
model.add(keras.layers.LSTM(8))
model.add(keras.layers.Dense(1, activation='sigmoid'))
  • Let's look at the model structure.
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 100, 16)           8000      
_________________________________________________________________
lstm (LSTM)                  (None, 8)                 800       
_________________________________________________________________
dense (Dense)                (None, 1)                 9         
=================================================================
Total params: 8,809
Trainable params: 8,809
Non-trainable params: 0
_________________________________________________________________
  • Last time the SimpleRNN layer had 200 model parameters (16×8 input weights + 8×8 recurrent weights + 8 biases). An LSTM cell contains four such small cells, so the count grows by exactly a factor of four to 800.
  • Let's compile the model and run the training. The batch size is 64 and the number of epochs is 100; the checkpoint and early-stopping code is the same as before.
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model.compile(optimizer=rmsprop, loss='binary_crossentropy', 
              metrics=['accuracy'])

checkpoint_cb = keras.callbacks.ModelCheckpoint('best-lstm-model.h5', 
                                                save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=3,
                                                  restore_best_weights=True)

history = model.fit(train_seq, train_target, epochs=100, batch_size=64,
                   validation_data=(val_seq, val_target),
                   callbacks=[checkpoint_cb, early_stopping_cb])
Epoch 1/100
313/313 [==============================] - 10s 26ms/step - loss: 0.6927 - accuracy: 0.5314 - val_loss: 0.6920 - val_accuracy: 0.5662
Epoch 2/100
313/313 [==============================] - 7s 24ms/step - loss: 0.6906 - accuracy: 0.6061 - val_loss: 0.6893 - val_accuracy: 0.6270
Epoch 3/100
313/313 [==============================] - 7s 24ms/step - loss: 0.6858 - accuracy: 0.6496 - val_loss: 0.6828 - val_accuracy: 0.6524
Epoch 4/100
313/313 [==============================] - 8s 24ms/step - loss: 0.6737 - accuracy: 0.6758 - val_loss: 0.6631 - val_accuracy: 0.6828
Epoch 5/100
313/313 [==============================] - 7s 24ms/step - loss: 0.6189 - accuracy: 0.7182 - val_loss: 0.5647 - val_accuracy: 0.7258
Epoch 6/100
313/313 [==============================] - 7s 24ms/step - loss: 0.5455 - accuracy: 0.7382 - val_loss: 0.5359 - val_accuracy: 0.7490
Epoch 7/100
313/313 [==============================] - 7s 24ms/step - loss: 0.5219 - accuracy: 0.7558 - val_loss: 0.5165 - val_accuracy: 0.7640
Epoch 8/100
313/313 [==============================] - 7s 24ms/step - loss: 0.5035 - accuracy: 0.7688 - val_loss: 0.5023 - val_accuracy: 0.7690
Epoch 9/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4888 - accuracy: 0.7774 - val_loss: 0.4891 - val_accuracy: 0.7768
Epoch 10/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4758 - accuracy: 0.7850 - val_loss: 0.4792 - val_accuracy: 0.7838
Epoch 11/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4651 - accuracy: 0.7919 - val_loss: 0.4702 - val_accuracy: 0.7868
Epoch 12/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4560 - accuracy: 0.7965 - val_loss: 0.4634 - val_accuracy: 0.7888
Epoch 13/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4485 - accuracy: 0.7985 - val_loss: 0.4588 - val_accuracy: 0.7934
Epoch 14/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4421 - accuracy: 0.8038 - val_loss: 0.4531 - val_accuracy: 0.7930
Epoch 15/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4367 - accuracy: 0.8051 - val_loss: 0.4485 - val_accuracy: 0.7954
Epoch 16/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4323 - accuracy: 0.8073 - val_loss: 0.4456 - val_accuracy: 0.7964
Epoch 17/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4285 - accuracy: 0.8080 - val_loss: 0.4423 - val_accuracy: 0.7990
Epoch 18/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4253 - accuracy: 0.8091 - val_loss: 0.4414 - val_accuracy: 0.7948
Epoch 19/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4228 - accuracy: 0.8103 - val_loss: 0.4392 - val_accuracy: 0.7970
Epoch 20/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4207 - accuracy: 0.8109 - val_loss: 0.4380 - val_accuracy: 0.7990
Epoch 21/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4184 - accuracy: 0.8102 - val_loss: 0.4361 - val_accuracy: 0.8010
Epoch 22/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4168 - accuracy: 0.8119 - val_loss: 0.4367 - val_accuracy: 0.7996
Epoch 23/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4155 - accuracy: 0.8111 - val_loss: 0.4344 - val_accuracy: 0.7988
Epoch 24/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4145 - accuracy: 0.8116 - val_loss: 0.4339 - val_accuracy: 0.7976
Epoch 25/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4135 - accuracy: 0.8124 - val_loss: 0.4355 - val_accuracy: 0.7956
Epoch 26/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4123 - accuracy: 0.8135 - val_loss: 0.4329 - val_accuracy: 0.7988
Epoch 27/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4119 - accuracy: 0.8134 - val_loss: 0.4330 - val_accuracy: 0.7976
Epoch 28/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4112 - accuracy: 0.8127 - val_loss: 0.4325 - val_accuracy: 0.7986
Epoch 29/100
313/313 [==============================] - 7s 23ms/step - loss: 0.4103 - accuracy: 0.8124 - val_loss: 0.4319 - val_accuracy: 0.7996
Epoch 30/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4098 - accuracy: 0.8126 - val_loss: 0.4330 - val_accuracy: 0.8038
Epoch 31/100
313/313 [==============================] - 7s 24ms/step - loss: 0.4092 - accuracy: 0.8145 - val_loss: 0.4347 - val_accuracy: 0.7958
Epoch 32/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4083 - accuracy: 0.8138 - val_loss: 0.4325 - val_accuracy: 0.8014
  • Plotting the training and validation loss gives the graph below.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'])
plt.show()

[Figure: training and validation loss of the LSTM model per epoch]

  • The result is shown above. Next, let's apply dropout, which we covered earlier, to the recurrent layer.

3. Applying Dropout to the Recurrent Layer

  • In fully connected and convolutional networks we used the Dropout class to apply dropout and keep the model from overfitting the training set too badly. Recurrent layers provide dropout on their own: the SimpleRNN and LSTM classes both have a dropout parameter and a recurrent_dropout parameter.
  • dropout applies dropout to the cell's inputs, while recurrent_dropout applies it to the recurring hidden state. For technical reasons, however, using recurrent_dropout means the model cannot be trained with the GPU-accelerated implementation, which slows training down considerably, so here we will use only dropout.
model2 = keras.Sequential()
model2.add(keras.layers.Embedding(500, 16, input_length=100))
model2.add(keras.layers.LSTM(8, dropout=0.3))
model2.add(keras.layers.Dense(1, activation='sigmoid'))
model2.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_1 (Embedding)      (None, 100, 16)           8000      
_________________________________________________________________
lstm_1 (LSTM)                (None, 8)                 800       
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 9         
=================================================================
Total params: 8,809
Trainable params: 8,809
Non-trainable params: 0
_________________________________________________________________
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model2.compile(optimizer=rmsprop, loss='binary_crossentropy', 
              metrics=['accuracy'])

checkpoint_cb = keras.callbacks.ModelCheckpoint('best-dropout-model.h5', 
                                                save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=3,
                                                  restore_best_weights=True)

history = model2.fit(train_seq, train_target, epochs=100, batch_size=64,
                   validation_data=(val_seq, val_target),
                   callbacks=[checkpoint_cb, early_stopping_cb])
Epoch 1/100
313/313 [==============================] - 9s 25ms/step - loss: 0.6924 - accuracy: 0.5368 - val_loss: 0.6915 - val_accuracy: 0.5540
Epoch 2/100
313/313 [==============================] - 8s 24ms/step - loss: 0.6882 - accuracy: 0.6065 - val_loss: 0.6835 - val_accuracy: 0.6858
Epoch 3/100
313/313 [==============================] - 8s 24ms/step - loss: 0.6516 - accuracy: 0.6911 - val_loss: 0.6115 - val_accuracy: 0.6950
Epoch 4/100
313/313 [==============================] - 8s 25ms/step - loss: 0.6008 - accuracy: 0.6999 - val_loss: 0.5880 - val_accuracy: 0.7170
Epoch 5/100
313/313 [==============================] - 8s 25ms/step - loss: 0.5796 - accuracy: 0.7225 - val_loss: 0.5699 - val_accuracy: 0.7296
Epoch 6/100
313/313 [==============================] - 8s 24ms/step - loss: 0.5616 - accuracy: 0.7375 - val_loss: 0.5516 - val_accuracy: 0.7470
Epoch 7/100
313/313 [==============================] - 8s 24ms/step - loss: 0.5459 - accuracy: 0.7483 - val_loss: 0.5390 - val_accuracy: 0.7600
Epoch 8/100
313/313 [==============================] - 8s 24ms/step - loss: 0.5296 - accuracy: 0.7576 - val_loss: 0.5232 - val_accuracy: 0.7580
Epoch 9/100
313/313 [==============================] - 8s 24ms/step - loss: 0.5159 - accuracy: 0.7653 - val_loss: 0.5123 - val_accuracy: 0.7708
Epoch 10/100
313/313 [==============================] - 8s 24ms/step - loss: 0.5027 - accuracy: 0.7730 - val_loss: 0.4996 - val_accuracy: 0.7718
Epoch 11/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4910 - accuracy: 0.7807 - val_loss: 0.4894 - val_accuracy: 0.7776
Epoch 12/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4808 - accuracy: 0.7840 - val_loss: 0.4806 - val_accuracy: 0.7836
Epoch 13/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4705 - accuracy: 0.7886 - val_loss: 0.4736 - val_accuracy: 0.7842
Epoch 14/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4652 - accuracy: 0.7908 - val_loss: 0.4732 - val_accuracy: 0.7822
Epoch 15/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4581 - accuracy: 0.7958 - val_loss: 0.4607 - val_accuracy: 0.7902
Epoch 16/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4504 - accuracy: 0.7976 - val_loss: 0.4554 - val_accuracy: 0.7962
Epoch 17/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4457 - accuracy: 0.7997 - val_loss: 0.4512 - val_accuracy: 0.7990
Epoch 18/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4403 - accuracy: 0.8020 - val_loss: 0.4473 - val_accuracy: 0.7990
Epoch 19/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4361 - accuracy: 0.8031 - val_loss: 0.4443 - val_accuracy: 0.8000
Epoch 20/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4344 - accuracy: 0.8052 - val_loss: 0.4449 - val_accuracy: 0.7954
Epoch 21/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4325 - accuracy: 0.8024 - val_loss: 0.4425 - val_accuracy: 0.7968
Epoch 22/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4289 - accuracy: 0.8061 - val_loss: 0.4410 - val_accuracy: 0.8008
Epoch 23/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4268 - accuracy: 0.8067 - val_loss: 0.4371 - val_accuracy: 0.7988
Epoch 24/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4245 - accuracy: 0.8067 - val_loss: 0.4352 - val_accuracy: 0.7978
Epoch 25/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4229 - accuracy: 0.8070 - val_loss: 0.4333 - val_accuracy: 0.8006
Epoch 26/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4206 - accuracy: 0.8076 - val_loss: 0.4369 - val_accuracy: 0.7976
Epoch 27/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4198 - accuracy: 0.8082 - val_loss: 0.4324 - val_accuracy: 0.8006
Epoch 28/100
313/313 [==============================] - 8s 25ms/step - loss: 0.4195 - accuracy: 0.8106 - val_loss: 0.4319 - val_accuracy: 0.8010
Epoch 29/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4177 - accuracy: 0.8101 - val_loss: 0.4313 - val_accuracy: 0.8006
Epoch 30/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4167 - accuracy: 0.8086 - val_loss: 0.4326 - val_accuracy: 0.7956
Epoch 31/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4166 - accuracy: 0.8093 - val_loss: 0.4318 - val_accuracy: 0.7948
Epoch 32/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4141 - accuracy: 0.8108 - val_loss: 0.4301 - val_accuracy: 0.7990
Epoch 33/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4160 - accuracy: 0.8087 - val_loss: 0.4304 - val_accuracy: 0.8048
Epoch 34/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4139 - accuracy: 0.8081 - val_loss: 0.4313 - val_accuracy: 0.8058
Epoch 35/100
313/313 [==============================] - 8s 24ms/step - loss: 0.4136 - accuracy: 0.8113 - val_loss: 0.4321 - val_accuracy: 0.8052
  • Training was run under exactly the same conditions as before.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'])
plt.show()

[Figure: training and validation loss of the LSTM model with dropout]

  • The gap between the training loss and the validation loss has narrowed compared to before.
  • Just as with dense or convolutional layers, there is no reason not to stack several recurrent layers. Next, let's train a model that connects two recurrent layers.

4. Connecting Two Layers

  • There is one thing to watch out for when connecting recurrent layers. As mentioned before, a recurrent layer passes only the hidden state of the last timestep of each sample to the next layer. But when recurrent layers are stacked, every recurrent layer needs sequential data as its input, so the earlier recurrent layer must output the hidden state for every timestep.
  • To make a Keras recurrent layer output the hidden state at every timestep, set the return_sequences parameter to True on every recurrent layer except the last one.
model3 = keras.Sequential()
model3.add(keras.layers.Embedding(500, 16, input_length=100))
model3.add(keras.layers.LSTM(8, dropout=0.3, return_sequences=True))
model3.add(keras.layers.LSTM(8, dropout=0.3))
model3.add(keras.layers.Dense(1, activation='sigmoid'))
  • We stacked two LSTM layers, both with dropout set to 0.3, and set return_sequences to True on the first LSTM. Let's check the result with summary().
model3.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, 100, 16)           8000      
_________________________________________________________________
lstm_2 (LSTM)                (None, 100, 8)            800       
_________________________________________________________________
lstm_3 (LSTM)                (None, 8)                 544       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 9         
=================================================================
Total params: 9,353
Trainable params: 9,353
Non-trainable params: 0
_________________________________________________________________
  • Because the first LSTM layer outputs the hidden state for all 100 timesteps, its output shape is shown as (None, 100, 8). The second LSTM layer outputs only the hidden state of the last timestep, so its output shape is (None, 8); since its input is now the 8-dimensional hidden state instead of the 16-dimensional embedding, it has 4 × (8×8 + 8×8 + 8) = 544 parameters.
  • Let's train this model the same way as before.
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model3.compile(optimizer=rmsprop, loss='binary_crossentropy', 
              metrics=['accuracy'])

checkpoint_cb = keras.callbacks.ModelCheckpoint('best-dual-model.h5', 
                                                save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=3,
                                                  restore_best_weights=True)

history = model3.fit(train_seq, train_target, epochs=100, batch_size=64,
                   validation_data=(val_seq, val_target),
                   callbacks=[checkpoint_cb, early_stopping_cb])
Epoch 1/100
313/313 [==============================] - 18s 48ms/step - loss: 0.4864 - accuracy: 0.7742 - val_loss: 0.4853 - val_accuracy: 0.7736
Epoch 2/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4750 - accuracy: 0.7836 - val_loss: 0.4699 - val_accuracy: 0.7858
Epoch 3/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4668 - accuracy: 0.7833 - val_loss: 0.4622 - val_accuracy: 0.7858
Epoch 4/100
313/313 [==============================] - 15s 46ms/step - loss: 0.4610 - accuracy: 0.7894 - val_loss: 0.4597 - val_accuracy: 0.7896
Epoch 5/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4569 - accuracy: 0.7904 - val_loss: 0.4586 - val_accuracy: 0.7896
Epoch 6/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4526 - accuracy: 0.7934 - val_loss: 0.4535 - val_accuracy: 0.7934
Epoch 7/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4501 - accuracy: 0.7928 - val_loss: 0.4518 - val_accuracy: 0.7886
Epoch 8/100
313/313 [==============================] - 15s 46ms/step - loss: 0.4446 - accuracy: 0.7968 - val_loss: 0.4472 - val_accuracy: 0.7954
Epoch 9/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4423 - accuracy: 0.7966 - val_loss: 0.4484 - val_accuracy: 0.7880
Epoch 10/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4406 - accuracy: 0.7964 - val_loss: 0.4477 - val_accuracy: 0.7880
Epoch 11/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4401 - accuracy: 0.8005 - val_loss: 0.4437 - val_accuracy: 0.7922
Epoch 12/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4392 - accuracy: 0.8004 - val_loss: 0.4414 - val_accuracy: 0.7946
Epoch 13/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4373 - accuracy: 0.7989 - val_loss: 0.4457 - val_accuracy: 0.7884
Epoch 14/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4323 - accuracy: 0.8019 - val_loss: 0.4431 - val_accuracy: 0.7988
Epoch 15/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4334 - accuracy: 0.8023 - val_loss: 0.4390 - val_accuracy: 0.7988
Epoch 16/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4316 - accuracy: 0.8034 - val_loss: 0.4385 - val_accuracy: 0.7972
Epoch 17/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4297 - accuracy: 0.8018 - val_loss: 0.4377 - val_accuracy: 0.7952
Epoch 18/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4294 - accuracy: 0.8012 - val_loss: 0.4387 - val_accuracy: 0.7984
Epoch 19/100
313/313 [==============================] - 14s 46ms/step - loss: 0.4288 - accuracy: 0.8009 - val_loss: 0.4399 - val_accuracy: 0.8000
Epoch 20/100
313/313 [==============================] - 15s 46ms/step - loss: 0.4261 - accuracy: 0.8055 - val_loss: 0.4356 - val_accuracy: 0.7994
Epoch 21/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4275 - accuracy: 0.8056 - val_loss: 0.4368 - val_accuracy: 0.8014
Epoch 22/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4258 - accuracy: 0.8039 - val_loss: 0.4348 - val_accuracy: 0.7990
Epoch 23/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4230 - accuracy: 0.8063 - val_loss: 0.4341 - val_accuracy: 0.7994
Epoch 24/100
313/313 [==============================] - 15s 49ms/step - loss: 0.4239 - accuracy: 0.8062 - val_loss: 0.4364 - val_accuracy: 0.8030
Epoch 25/100
313/313 [==============================] - 15s 49ms/step - loss: 0.4219 - accuracy: 0.8065 - val_loss: 0.4338 - val_accuracy: 0.7994
Epoch 26/100
313/313 [==============================] - 17s 54ms/step - loss: 0.4229 - accuracy: 0.8048 - val_loss: 0.4331 - val_accuracy: 0.8004
Epoch 27/100
313/313 [==============================] - 16s 52ms/step - loss: 0.4215 - accuracy: 0.8084 - val_loss: 0.4349 - val_accuracy: 0.7974
Epoch 28/100
313/313 [==============================] - 15s 48ms/step - loss: 0.4206 - accuracy: 0.8070 - val_loss: 0.4341 - val_accuracy: 0.7970
Epoch 29/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4238 - accuracy: 0.8043 - val_loss: 0.4330 - val_accuracy: 0.8000
Epoch 30/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4207 - accuracy: 0.8084 - val_loss: 0.4330 - val_accuracy: 0.8032
Epoch 31/100
313/313 [==============================] - 15s 47ms/step - loss: 0.4191 - accuracy: 0.8076 - val_loss: 0.4345 - val_accuracy: 0.8054
Epoch 32/100
313/313 [==============================] - 15s 48ms/step - loss: 0.4194 - accuracy: 0.8089 - val_loss: 0.4307 - val_accuracy: 0.8024
Epoch 33/100
313/313 [==============================] - 17s 55ms/step - loss: 0.4165 - accuracy: 0.8105 - val_loss: 0.4319 - val_accuracy: 0.8032
Epoch 34/100
313/313 [==============================] - 18s 56ms/step - loss: 0.4173 - accuracy: 0.8082 - val_loss: 0.4360 - val_accuracy: 0.8034
Epoch 35/100
313/313 [==============================] - 17s 54ms/step - loss: 0.4156 - accuracy: 0.8114 - val_loss: 0.4315 - val_accuracy: 0.8030
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'])
plt.show()

[Figure: training and validation loss of the two-layer LSTM model]

  • Although the run stopped fairly early and we cannot claim the loss was lowered while keeping overfitting perfectly under control, we have now trained with LSTM cells, applied dropout, and built a recurrent network with two stacked layers.
  • Next, let's look at another well-known cell, the GRU cell.

5. GRU Structure

  • GRU stands for Gated Recurrent Unit. You can think of this cell as a simplified version of the LSTM.
  • Unlike the LSTM it does not compute a cell state; it carries only a hidden state.
  • A GRU cell contains three small cells that multiply the hidden state and the input by weights and add a bias. Two of them use the sigmoid activation function and one uses the hyperbolic tangent function. Here, too, the weights applied to the hidden state and to the input are drawn as a single combined set.
  • Because a GRU cell has fewer weights than an LSTM it requires less computation, yet it is known to perform nearly as well as the LSTM. A minimal sketch of one GRU step follows below.
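  • As with the LSTM sketch earlier, here is a rough NumPy sketch of a single GRU timestep in its classic formulation; the names are made up for illustration, and TensorFlow's default implementation differs slightly in one spot, which the next section comes back to.

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

input_dim, units = 16, 8
rng = np.random.default_rng(0)

# three small cells: update gate z, reset gate r, candidate state g
W = {g: rng.normal(size=(input_dim, units)) * 0.1 for g in "zrg"}
U = {g: rng.normal(size=(units, units)) * 0.1 for g in "zrg"}
b = {g: np.zeros(units) for g in "zrg"}

def gru_step(x, h_prev):
    z = sigmoid(x @ W["z"] + h_prev @ U["z"] + b["z"])        # update gate (sigmoid)
    r = sigmoid(x @ W["r"] + h_prev @ U["r"] + b["r"])        # reset gate (sigmoid)
    g = np.tanh(x @ W["g"] + (r * h_prev) @ U["g"] + b["g"])  # candidate state (tanh)
    return z * h_prev + (1 - z) * g       # new hidden state: a blend of old state and candidate

h = gru_step(rng.normal(size=input_dim), np.zeros(units))
print(h.shape)                            # (8,)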

6. Training a GRU Network

model4 = keras.Sequential()
model4.add(keras.layers.Embedding(500, 16, input_length=100))
model4.add(keras.layers.GRU(8))
model4.add(keras.layers.Dense(1, activation='sigmoid'))
model4.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_3 (Embedding)      (None, 100, 16)           8000      
_________________________________________________________________
gru (GRU)                    (None, 8)                 624       
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 9         
=================================================================
Total params: 8,633
Trainable params: 8,633
Non-trainable params: 0
_________________________________________________________________
  • Not much has changed. Let's work out the GRU layer's parameter count. A GRU cell contains 3 small cells, and each small cell has weights applied to the input, weights applied to the hidden state, and biases. The input weights number 16×8 = 128, the hidden-state weights 8×8 = 64, and with one bias per neuron there are 8 biases. That adds up to 200 per small cell, and with 3 small cells the layer should need 600 model parameters. The result above, however, shows 24 more.
  • The GRU cell that TensorFlow implements by default computes things a little differently from the description above. In the earlier description, the small cell's output is multiplied with the hidden state first; in TensorFlow, the hidden state is multiplied by the weights first and only then multiplied by the small cell's output. Because that term is written separately, each small cell gets one extra bias vector, and with 8 neurons this adds 3 × 8 = 24 model parameters. A quick arithmetic check follows after this list.
  • Now let's run the training.
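  • Before running the training below, the 624 figure can be reproduced with plain arithmetic; this is only a sanity check assuming the extra per-small-cell recurrent bias described above (TensorFlow's default behaviour).

input_dim, units, n_small_cells = 16, 8, 3
params = n_small_cells * (input_dim * units    # input weights
                          + units * units      # recurrent weights
                          + units              # input bias
                          + units)             # extra recurrent bias in TensorFlow's default GRU
print(params)                                  # 624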
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model4.compile(optimizer=rmsprop, loss='binary_crossentropy', 
              metrics=['accuracy'])

checkpoint_cb = keras.callbacks.ModelCheckpoint('best-GRU-model.h5', 
                                                save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=3,
                                                  restore_best_weights=True)

history = model4.fit(train_seq, train_target, epochs=100, batch_size=64,
                   validation_data=(val_seq, val_target),
                   callbacks=[checkpoint_cb, early_stopping_cb])
Epoch 1/100
313/313 [==============================] - 11s 30ms/step - loss: 0.6923 - accuracy: 0.5411 - val_loss: 0.6915 - val_accuracy: 0.5782
Epoch 2/100
313/313 [==============================] - 9s 28ms/step - loss: 0.6898 - accuracy: 0.5865 - val_loss: 0.6885 - val_accuracy: 0.6008
Epoch 3/100
313/313 [==============================] - 9s 28ms/step - loss: 0.6852 - accuracy: 0.6145 - val_loss: 0.6828 - val_accuracy: 0.6140
Epoch 4/100
313/313 [==============================] - 9s 28ms/step - loss: 0.6770 - accuracy: 0.6335 - val_loss: 0.6732 - val_accuracy: 0.6314
Epoch 5/100
313/313 [==============================] - 9s 28ms/step - loss: 0.6631 - accuracy: 0.6528 - val_loss: 0.6564 - val_accuracy: 0.6460
Epoch 6/100
313/313 [==============================] - 9s 28ms/step - loss: 0.6365 - accuracy: 0.6719 - val_loss: 0.6188 - val_accuracy: 0.6824
Epoch 7/100
313/313 [==============================] - 9s 28ms/step - loss: 0.5781 - accuracy: 0.7107 - val_loss: 0.5600 - val_accuracy: 0.7328
Epoch 8/100
313/313 [==============================] - 9s 29ms/step - loss: 0.5374 - accuracy: 0.7455 - val_loss: 0.5380 - val_accuracy: 0.7476
Epoch 9/100
313/313 [==============================] - 9s 28ms/step - loss: 0.5193 - accuracy: 0.7566 - val_loss: 0.5235 - val_accuracy: 0.7530
Epoch 10/100
313/313 [==============================] - 9s 28ms/step - loss: 0.5051 - accuracy: 0.7645 - val_loss: 0.5120 - val_accuracy: 0.7610
Epoch 11/100
313/313 [==============================] - 9s 30ms/step - loss: 0.4940 - accuracy: 0.7717 - val_loss: 0.5029 - val_accuracy: 0.7626
Epoch 12/100
313/313 [==============================] - 9s 30ms/step - loss: 0.4841 - accuracy: 0.7773 - val_loss: 0.4944 - val_accuracy: 0.7696
Epoch 13/100
313/313 [==============================] - 9s 30ms/step - loss: 0.4761 - accuracy: 0.7822 - val_loss: 0.4871 - val_accuracy: 0.7742
Epoch 14/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4694 - accuracy: 0.7872 - val_loss: 0.4809 - val_accuracy: 0.7752
Epoch 15/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4635 - accuracy: 0.7882 - val_loss: 0.4782 - val_accuracy: 0.7740
Epoch 16/100
313/313 [==============================] - 9s 29ms/step - loss: 0.4582 - accuracy: 0.7921 - val_loss: 0.4715 - val_accuracy: 0.7838
Epoch 17/100
313/313 [==============================] - 9s 29ms/step - loss: 0.4539 - accuracy: 0.7948 - val_loss: 0.4689 - val_accuracy: 0.7794
Epoch 18/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4491 - accuracy: 0.7963 - val_loss: 0.4673 - val_accuracy: 0.7760
Epoch 19/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4455 - accuracy: 0.7994 - val_loss: 0.4659 - val_accuracy: 0.7766
Epoch 20/100
313/313 [==============================] - 9s 30ms/step - loss: 0.4422 - accuracy: 0.8002 - val_loss: 0.4595 - val_accuracy: 0.7852
Epoch 21/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4388 - accuracy: 0.8019 - val_loss: 0.4574 - val_accuracy: 0.7878
Epoch 22/100
313/313 [==============================] - 9s 30ms/step - loss: 0.4359 - accuracy: 0.8042 - val_loss: 0.4561 - val_accuracy: 0.7902
Epoch 23/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4339 - accuracy: 0.8049 - val_loss: 0.4572 - val_accuracy: 0.7898
Epoch 24/100
313/313 [==============================] - 10s 32ms/step - loss: 0.4318 - accuracy: 0.8076 - val_loss: 0.4541 - val_accuracy: 0.7872
Epoch 25/100
313/313 [==============================] - 10s 30ms/step - loss: 0.4302 - accuracy: 0.8073 - val_loss: 0.4536 - val_accuracy: 0.7916
Epoch 26/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4278 - accuracy: 0.8098 - val_loss: 0.4538 - val_accuracy: 0.7848
Epoch 27/100
313/313 [==============================] - 9s 29ms/step - loss: 0.4271 - accuracy: 0.8080 - val_loss: 0.4499 - val_accuracy: 0.7916
Epoch 28/100
313/313 [==============================] - 10s 31ms/step - loss: 0.4255 - accuracy: 0.8112 - val_loss: 0.4493 - val_accuracy: 0.7920
Epoch 29/100
313/313 [==============================] - 9s 29ms/step - loss: 0.4240 - accuracy: 0.8108 - val_loss: 0.4507 - val_accuracy: 0.7918
Epoch 30/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4225 - accuracy: 0.8117 - val_loss: 0.4488 - val_accuracy: 0.7916
Epoch 31/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4215 - accuracy: 0.8132 - val_loss: 0.4484 - val_accuracy: 0.7920
Epoch 32/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4204 - accuracy: 0.8133 - val_loss: 0.4471 - val_accuracy: 0.7940
Epoch 33/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4197 - accuracy: 0.8138 - val_loss: 0.4463 - val_accuracy: 0.7950
Epoch 34/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4192 - accuracy: 0.8136 - val_loss: 0.4463 - val_accuracy: 0.7946
Epoch 35/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4180 - accuracy: 0.8136 - val_loss: 0.4452 - val_accuracy: 0.7948
Epoch 36/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4171 - accuracy: 0.8156 - val_loss: 0.4457 - val_accuracy: 0.7922
Epoch 37/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4163 - accuracy: 0.8156 - val_loss: 0.4439 - val_accuracy: 0.7958
Epoch 38/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4157 - accuracy: 0.8140 - val_loss: 0.4495 - val_accuracy: 0.7950
Epoch 39/100
313/313 [==============================] - 9s 27ms/step - loss: 0.4152 - accuracy: 0.8140 - val_loss: 0.4431 - val_accuracy: 0.7966
Epoch 40/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4146 - accuracy: 0.8163 - val_loss: 0.4447 - val_accuracy: 0.7930
Epoch 41/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4141 - accuracy: 0.8166 - val_loss: 0.4424 - val_accuracy: 0.7952
Epoch 42/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4135 - accuracy: 0.8159 - val_loss: 0.4428 - val_accuracy: 0.7938
Epoch 43/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4130 - accuracy: 0.8152 - val_loss: 0.4435 - val_accuracy: 0.7922
Epoch 44/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4123 - accuracy: 0.8170 - val_loss: 0.4408 - val_accuracy: 0.7956
Epoch 45/100
313/313 [==============================] - 9s 29ms/step - loss: 0.4121 - accuracy: 0.8166 - val_loss: 0.4412 - val_accuracy: 0.7988
Epoch 46/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4110 - accuracy: 0.8179 - val_loss: 0.4437 - val_accuracy: 0.7956
Epoch 47/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4109 - accuracy: 0.8171 - val_loss: 0.4400 - val_accuracy: 0.7962
Epoch 48/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4103 - accuracy: 0.8185 - val_loss: 0.4393 - val_accuracy: 0.7964
Epoch 49/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4098 - accuracy: 0.8184 - val_loss: 0.4393 - val_accuracy: 0.7982
Epoch 50/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4096 - accuracy: 0.8175 - val_loss: 0.4383 - val_accuracy: 0.7986
Epoch 51/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4093 - accuracy: 0.8173 - val_loss: 0.4379 - val_accuracy: 0.7998
Epoch 52/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4089 - accuracy: 0.8181 - val_loss: 0.4375 - val_accuracy: 0.7992
Epoch 53/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4084 - accuracy: 0.8172 - val_loss: 0.4366 - val_accuracy: 0.7988
Epoch 54/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4074 - accuracy: 0.8184 - val_loss: 0.4368 - val_accuracy: 0.8000
Epoch 55/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4073 - accuracy: 0.8194 - val_loss: 0.4421 - val_accuracy: 0.7932
Epoch 56/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4073 - accuracy: 0.8166 - val_loss: 0.4358 - val_accuracy: 0.7984
Epoch 57/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4067 - accuracy: 0.8194 - val_loss: 0.4351 - val_accuracy: 0.7996
Epoch 58/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4065 - accuracy: 0.8194 - val_loss: 0.4362 - val_accuracy: 0.7996
Epoch 59/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4061 - accuracy: 0.8196 - val_loss: 0.4362 - val_accuracy: 0.8008
Epoch 60/100
313/313 [==============================] - 9s 28ms/step - loss: 0.4058 - accuracy: 0.8169 - val_loss: 0.4383 - val_accuracy: 0.7992
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'])
plt.show()

[Figure: training and validation loss of the GRU model]

  • Because dropout was not used there is a gap between the training loss and the validation loss, but we can see that the training converges reasonably well.
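  • Since the ModelCheckpoint callback saved the best weights to best-GRU-model.h5, a natural follow-up (not run here) would be to reload that model and evaluate it on the test set after the same padding; a minimal sketch assuming the variables loaded at the top:

test_seq = pad_sequences(test_input, maxlen=100)   # same preprocessing as the training data

best_model = keras.models.load_model('best-GRU-model.h5')
best_model.evaluate(test_seq, test_target)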

7. Wrapping Up

  • We have finally been able to take a brief tour of machine learning as a whole and the basics of deep learning. There is still a long way to go; so far we have only dipped a toe into a very large ocean.
  • I would recommend this book to beginners starting out in machine learning or deep learning, and I think it is also a good choice for anyone who needs to go back over the fundamentals.
  • From here I plan to go a bit deeper with Hands-On Machine Learning.
