알고리즘 풀이/백준

[ 백준 5557 ] 1학년 - Python

12.tka 2020. 7. 4. 21:01
728x90

문제 보기

이 문제는 DP 문제이다.

 

처음에는 완전 탐색(BFS)으로 접근을 하였다.

전체 반복 횟수가 최대 100이라는 것만 생각하였으며 반복하는 동안 큐의 삽입, 삭제 연산 횟수까지 고려하지 못했었다.

 

문제점을 찾은 후 시간 복잡도를 줄이기 위하여 DP를 생각하였다. 이전의 숫자 정보를 계속해서 가지고 있어야 하기 때문에 1차원 DP로는 불가능하였고 2차원 DP로 구현하였다.

 

dp[i][j] = i번째 숫자까지 연산을 진행했을 때 j값을 나타낼 수 있는 경우의 수

 

위 식을 생각한 후 코드를 구현하는 것은 어렵지 않았다. 

 

1. 입력받은 수를 num_list의 리스트에 대입

2. dp의 크기 (N - 1) * 21로 선언한 후 0으로 초기화

   - 행의 크기가 N - 1인 이유는 최종적인 값을 계산하기 위한 과정은 N - 1까지이기 때문이다.

   - 열의 크기가 21인 이유는 가능한 수가 0 이상 20 이하이기 때문이다.

3. dp[0][num_list[0]] = 1을 대입한다.

4. 2중 for문을 돌면서 dp[i - 1][j]의 값이 0인지 확인한다.

   - 0이 아니라면 해당 dp[i -1][j]의 인덱스(j)를 활용하여 현재 값을 추가할 수 있는지 판단한다.

   - 추가할 수 있다면 이전의 dp[i - 1][j] 값을 더해준다.

 

코드

if __name__ == "__main__":

    N = int(input())  # 숫자의 개수
    num_list = list(map(int, input().split()))  # 숫자 리스트

    # 2차원 dp 초기화
    dp = []
    for i in range(N - 1):
        dp.append([0 for _ in range(21)])

    dp[0][num_list[0]] = 1
    for i in range(1, N - 1):
        for j in range(21):
            if dp[i - 1][j] != 0:
                pre_value = j
                next_value = num_list[i]
                if 0 <= pre_value + next_value <= 20:
                    dp[i][pre_value + next_value] += dp[i - 1][pre_value]
                if 0 <= pre_value - next_value <= 20:
                    dp[i][pre_value - next_value] += dp[i - 1][pre_value]

    print(dp[N - 2][num_list[N - 1]])

 

추가로 시간 초과가 발생한 bfs코드는 아래와 같습니다.

import sys
from collections import deque
sys.setrecursionlimit(10**6)


def bfs():
    answer = num_list[N - 1]  # 만들어야하는 최종 값
    q = deque([num_list[0]])  # 현재 값, 진행한 인덱스 값

    for i in range(1, N - 1):
        value = num_list[i]
        temp = []
        while q:  # 큐가 빌때까지
            current_val = q.popleft()
            if 0 <= current_val + value <= 20:
                temp.append(current_val + value)
            if 0 <= current_val - value <= 20:
                temp.append(current_val - value)
        q.extend(temp)

    result = list(filter(lambda x: x == answer, q))
    return len(result)


if __name__ == "__main__":
    N = int(input())  # 숫자의 개수
    num_list = list(map(int, input().split()))  # 숫자 리스트

    print(bfs())

 

728x90