관리 메뉴

프로그램 소스

Keras를 이용한 CNN, 손글씨 인식 소스 코드 본문

빅데이터/머신 러닝 with Python

Keras를 이용한 CNN, 손글씨 인식 소스 코드

Only 프로그램 소스 언제나휴일 2020. 8. 17. 18:03
import keras
from keras import models, layers
from keras import backend
from keras import datasets

# Training hyper-parameters.
batch_size = 128
epochs = 10
num_classes = 10  # ten label classes (handwritten digits)

# Load MNIST; each image array is 3-D: (count, rows, cols).
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
img_rows, img_cols = x_train.shape[1:]

# Decide where the single channel axis goes, based on the backend's layout,
# then reshape both splits to (count,) + input_shape.
if backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.reshape((len(x_train),) + input_shape)
x_test = x_test.reshape((len(x_test),) + input_shape)

# Convert to float32 and scale by 1/255 (presumably 8-bit pixel intensities,
# giving values in [0, 1] — TODO confirm dataset dtype/range).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# One-hot encode the integer labels for categorical cross-entropy.
y_train, y_test = [keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test)]

# Build the model: two 3x3 convolutions, 2x2 max-pooling, dropout, then a
# dense classifier head ending in a softmax over num_classes outputs.
model = models.Sequential([
    layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                  input_shape=input_shape),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax'),
])

# Compile with categorical cross-entropy (labels are one-hot) and Adam.
# 'rmsprop' was noted by the author as an alternative optimizer to try.
model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy'],
)
# Train the model, holding out 20% of the training arrays for validation;
# the returned History object is kept for plotting.
hist = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
)

# Evaluate on the test split: element 0 is the loss, element 1 the
# accuracy metric requested at compile time.
score = model.evaluate(x_test, y_test)
test_loss, test_accuracy = score[0], score[1]

print()
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

# Plot the training history: loss on the left y-axis and accuracy on a twin
# right y-axis, both against the (0-based) epoch index.
import matplotlib.pyplot as plt

history = hist.history
# Keras releases before 2.3 logged metrics=['accuracy'] under the key 'acc'
# (and 'val_acc'); newer releases use 'accuracy'. Accept either so the plot
# does not raise KeyError on older installations.
acc_key = 'accuracy' if 'accuracy' in history else 'acc'

fig, loss_ax = plt.subplots(figsize=(10, 5))
acc_ax = loss_ax.twinx()  # second y-axis sharing the epoch x-axis

loss_ax.plot(history['loss'], 'y', label='train loss')
loss_ax.plot(history['val_loss'], 'r', label='val loss')
acc_ax.plot(history[acc_key], 'b', label='train acc')
acc_ax.plot(history['val_' + acc_key], 'g', label='val acc')

loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
acc_ax.set_ylabel('accuracy')
# Separate legends per axis, stacked on the left so they do not overlap.
loss_ax.legend(loc='upper left')
acc_ax.legend(loc='lower left')

plt.show()
0 Comments
댓글쓰기 폼