Data Engineering/머신러닝

손글씨 숫자 인식(MNIST) - Python

12.tka 2020. 6. 1. 16:14
728x90

이번 글에서는 Python과 MNIST 데이터셋을 이용해서 손글씨 숫자를 인식하고자 한다.

MNIST 데이터셋

MNIST 데이터셋은 손글씨 숫자 이미지 집합이다. 기계학습 분야에서 아주 유명한 데이터셋으로, 간단한 실험부터 논문으로 발표되는 연구까지 다양한 곳에서 이용하고 있다.

 

0부터 9까지 숫자 이미지로 구성되며 훈련 이미지 60,000장, 시험 이미지가 10,000장으로 구성된다. 훈련 이미지를 사용해서 모델을 학습하고, 학습한 모델로 시험 이미지들을 얼마나 정확하게 분류하는지 추론한다.

 

각각의 이미지 데이터는 28 x 28 크기의 회색조 이미지(1채널)이며, 각 픽셀은 0에서 255까지의 값을 가진다. 또한 각 이미지에는 '7', '2', '1'과 같이 그 이미지가 실제 의미하는 숫자가 레이블로 붙어있다.


MNIST 데이터셋  읽어오기

데이터셋

위 링크에 접속하면 나오는 dataset 폴더에 MNIST와 관련된 코드들이 있다.

그 중 mnist.py에 있는 load_mnist() 함수를 이용할 것이다. dataset 폴더를 다운로드 한 후 작성하고자 하는 코드의 부모디렉토리에 위치시킨다.

 

import sys, os
sys.path.append(os.pardir) # 부모 디렉터리의 파일을 가져올 수 있도록 설정
from dataset.mnist import load_mnist

# (훈련 이미지, 훈련 레이블), (시험 이미지, 시험 레이블)
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

# 각 데이터의 형상 출력
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)

 

실행 결과

위 코드를 실행시키면 MNIST 데이터를 (훈련 이미지, 훈련 레이블), (시험 이미지, 시험 레이블) 형식으로 가져온다. 

load_mnist 함수에서 normalize, flatten, one_hot_label 세 가지를 설정할 수 있다.

 

normalize(정규화)는 입력 이미지의 픽셀 값을 0.0 ~ 1.0 사이의 값으로 정규화할지를 정한다. False로 설정하면 입력 이미지의 픽셀 원래 값 그대로 유지한다.

 

flatten은 입력 이미지를 평탄하게, 즉 1차원 배열로 만들지를 정한다. False로 설정하면 입력 이미지를 원래 값 그대로 유지하며, True로 설정하면 1 * 28 * 28의 3차원 배열을 784개로 이뤄진 1차원 배열로 저장한다.

 

one_hot_label은 레이블을 원-핫 인코딩(one-hot encoding) 형태로 만들지를 정한다. 원-핫 인코딩이란, 예를 들어 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] 처럼 정답을 뜻하는 원소만 1이고(hot하고) 나머지는 모두 0인 배열이다.


MNIST 이미지 불러오기

이번에는 PIL(python Image Library) 모듈을 이용해서 MNIST 이미지를 화면으로 불러오고자 한다.

 

import sys, os
import numpy as np

sys.path.append(os.pardir)  # 부모 디렉터리의 파일을 가져올 수 있도록 설정
from dataset.mnist import load_mnist
from PIL import Image


def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


# (훈련 이미지, 훈련 레이블), (시험 이미지, 시험 레이블)
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[0]
label = t_train[0]
print(label)

print(img.shape)
img = img.reshape(28, 28)  # 크기 수정
print(img.shape)
img_show(img)

load_mnist로 MNIST를 읽어올 때 flatten=True로 설정하였기 때문에 이미지로 표시할 때 원래 형상인 28 x 28로 수정하였다. fromarray를 통해서 넘파이로 저장된 이미지 데이터를 PIL용 데이터 객체로 변환한다.

 

실행 결과

 

실행 결과 첫 번째 데이터 5를 성공적으로 불러왔다.

728x90