[이론] 냅색, 배낭 알고리즘

알고리즘 설명

냅색 문제는 dp 문제의 일종입니다. 첫 번째 문제인 벼락치기는 weight에 대한 value를 최대로 하는 문제입니다. 두 번째 문제는 weight를 채우는 경우의 수를 구하는 문제입니다. value를 구하는 것과 경우의 수를 구하는 것은 구현의 문제이고, 중요한 것은 같은 weight의 값을 중복해서 선택할 수 있는지에 대한 여부입니다. 벼락치기 문제는 한 단원에 한 문제를 출제한다라는 조건으로 중복 선택을 막았고, 동전문제는 중복 선택을 해야하는 문제입니다.

두 가지 모두 마지막에 $W_{x}$를 선택했을 때의 경우를 구한다는 점에서 동일하지만, 중복 선택 여부에 따라서 dp의 적용 방법이 다릅니다.

14728 벼락치기

벼락치기 문제를 기준으로 설명하겠습니다. 골드5이지만 정답 코드가 굉장히 짧기 때문입니다.

예제 입력으로 주어지는 값들을 아래와 같은 변수에 입력받는다고 가정하겠습니다.

1
2
3
4
3=n 310=t
50=k 40=s
100=k 70=s
200=k 150=s

dp[i] = i시간으로 얻을 수 있는 최대 점수라고 하겠습니다.

  1. dp[310] = max(dp[310], dp[260]+40)
  2. dp[310] = max(dp[310], dp[210]+70)
  3. dp[310] = max(dp[310], dp[110]+150)

구하고자하는 dp[310] 값은 위와 같은 3번의 비교를 통해서 구합니다. 가장 마지막에 공부하는 (k,s)쌍이 어떤 것인지에 따라서 결정이 됩니다.

다시 적어보면 아래와 같습니다.

  1. dp[310] = max(dp[310], dp[310 -50]+40) : 마지막에 50으로 채워지는 경우
  2. dp[310] = max(dp[310], dp[310 - 100]+70) : 마지막에 100으로 채워지는 경우
  3. dp[310] = max(dp[310], dp[310 - 150]+150) : 마지막에 200으로 채워지는 경우

코드로 표현하면 아래와 같습니다.

1
2
for i in range(3):
    dp[310] = max(dp[310], dp[310-k[i]]+s[i])

주의할 점은 dp[310]을 구하기 위해서 필요한 dp[250],dp[210],d[110]을 미리 구해놓아야 위 공식만으로 dp[310]을 구할 수 있습니다. 이를 일반화한 공식은 아래와 같습니다.

  • dp[i] = max(dp[i], dp[i-k[j]]+s[j])
  • dp[i] = max(기존값,마지막으로 j번째 값을 선택했을 경우의 값)

아래의 표는 벼락치기 문제와 아주 똑같은 평범한 배낭의 예제 입력으로 만든 표입니다. 아래 코드의 알고리즘과 비교해서 보시면 dp[i]가 최적이 되는 과정을 이해할 수 있습니다.

  • dp[1]+13=13 : 마지막으로 (6,13)을 선택하는 것이 최적
  • dp[0]+13=13 : 마지막으로 (6,13)을 선택하는 것이 최적
  • dp[3] + 8 = 8, max(13,8)=13 : 마지막으로 (4,8)을 선택할 수도 있지만, (6,13)을 선택하는 것이 최적
  • dp[2] + 8 = 8, max(13,8)=13 : 마지막으로 (4,8)을 선택할 수도 있지만, (6,13)을 선택하는 것이 최적
  dp[0] dp[1] dp[2] dp[3] dp[4] dp[5] dp[6] dp[7]
초기값 0 0 0 0 0 0 0 0
(6,13) 0 0 0 0 0 0 dp[0]+13=13 dp[1]+13=13
(4,8) 0 0 0 0 dp[0] + 8 = 8
max(0,8)=8
dp[1] + 8 = 8
max(0,8)=8
dp[2] + 8 = 8
max(13,8)=13
dp[3] + 8 = 8
max(13,8)=13
(3,6) 0 0 0 dp[0] + 6 = 6
max(8,6)=6
dp[1] + 6 = 0
max(8,6)=8
dp[2] + 6 = 0
max(8,0)=8
dp[3] + 6 = 0
max(13,0)=13
dp[4] + 6 = 14
max(13,14)=14
(5,12) 0 0 0 0 0 dp[0] + 12 = 12
max(8,12)=12
dp[1] + 12 = 12
max(13,12)=13
dp[2] + 12 = 12
max(14,12)=12
dp[i] 0 0 0 6 8 12 13 14

잘못된 풀이

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 잘못된 것을 보여주기 위해 dp길이를 32로 제한했습니다.
import sys
from math import sqrt, ceil
from bisect import bisect_left, bisect_right
from operator import itemgetter
from collections import Counter, deque
from copy import deepcopy

sys.stdin = open("input.txt", "r")

dp = [0] * 32
tc, time = map(int, sys.stdin.readline().split())
for _ in range(tc):
    k, s = map(int, sys.stdin.readline().split())
    # for i in range(time, k-1, -1):
    for i in range(k, time + 1):
        dp[i] = max(dp[i], dp[i - k] + s)
    print(dp)


print(dp[time])

'''
3 31
5 4
10 7
20 15

앞에서부터 (누적 가치를) 채우면 같은 값을 여러번 선택하는 경우가 발생한다. 
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 20, 20, 20, 20, 20, 24, 24]
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 20, 20, 20, 20, 20, 24, 24]
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 20, 20, 20, 20, 20, 24, 24]
'''

올바른 풀이

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import sys
from math import sqrt, ceil
from bisect import bisect_left, bisect_right
from operator import itemgetter
from collections import Counter, deque
from copy import deepcopy

sys.stdin = open("input.txt", "r")

dp = [0] * 10010
tc, time = map(int, sys.stdin.readline().split())
for _ in range(tc):
    k, s = map(int, sys.stdin.readline().split())
    for i in range(time, k-1, -1):
        dp[i] = max(dp[i], dp[i - k] + s)
    print(dp)


print(dp[time])
'''
3 31
5 4
10 7
20 15

뒤에서부터 (누적 가치를) 채우면 같은 값을 여러번 선택하는 경우가 없다. 
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11]
[0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 15, 15, 15, 15, 15, 19, 19, 19, 19, 19, 22, 22]
'''

9084 동전

동전

알고리즘

dp[i] : i원을 채우는 모든 경우의 수

동전이 1원짜리, 2원짜리 두 개가 있을 때 10원을 채우는 과정을 생각해보겠습니다. 0원을 동전으로 채우는 방법의 수를 1개로 정의합니다. 이후에 dp[i] (i원을 동전으로 채우는 방법)을 구하기 위한 방법은 두 가지로 나뉩니다.

  1. 0원을 채우는 방법은 아무 동전도 안쓰는 1가지뿐 : dp[0] = 1
  2. i-1원을 이미 채운 상태에서 1원짜리 동전을 더해서 i원을 채우는 경우 dp[i] += dp[i-1]
  3. i-2원을 이미 채운 상태에서 2원짜리 동전을 더해서 i원을 채우는 경우 dp[i] += dp[i-2]

그런데 주의할 점은 1원짜리로 채울 수 있는 모든 경우를 구한 뒤, 여기에 2원짜리를 추가해서 채울 수 있는 경우의 수를 더해야합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 올바른 경우
for coin in coins:
    for i in range(m - coin + 1):
        dp[i + coin] += dp[i]
# 올바른 경우 2
for c in coin:
    for m in range(1, M + 1):
        if m - c >= 0:
            ans[m] += ans[m-c]

# 틀린 풀이
for i in range(m):
    for coin in coins:
        if i + coin <= m:
            dp[i + coin] += dp[i]

올바른 경우를 카운트 했을 때는 아래와 같은 표로 생각할 수 있습니다.

dp[i] 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
초기값 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
5 1 1 1 1 1 2 2 2 2 2 3 3 3 3 3 4 4 4 4 4 5
10 1 1 1 1 1 2 2 2 2 2 4 4 4 4 4 6 6 6 6 6 9

틀린 풀이로 작성하는 경우 중복된 카운트가 생깁니다.

아래 각 cell의 (a,b,c)는 1원짜리가 a개, 5원짜리가 b개, 10원짜리가 c개인 경우를 나타냅니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 틀린 풀이
for i in range(14):
    for coin in [1,5,10]:
        if i + coin <= m:
            dp[i + coin] += dp[i]
'''
실행결과 print(dp)
[1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0]
[1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
[1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 2, (*3), 1, 1, 1, 3, 1, 1, 1, 1]
'''

위 실행 결과에서 (*3) 위치가 아래 표에서 ????의 위치입니다. 6을 동전으로 나타내는 방법은 두 가지 방법뿐인데 3이됩니다. 그 이유는 dp[5]에는 (0,1,0)과 (5,0,0)가 모두 들어있고, 여기에 1원이 추가되면서 (1,1,0)과 (6,0,0)이 더해지는데 이미 (1,1,0)은 존재하기 때문에 중복이 발생하고 (*3)이 생긴 것입니다.

dp[i] 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
초기값 0,0,0                            
(1,5,10)   1,0,0       0,1,0         0,0,1        
      2,0,0       1,1,0         1,0,1      
        3,0,0       2,1,0         2,1,1    
          4,0,0       3,1,0         3,1,1  
            5,0,0       4,1,0         4,1,1
              ?????                

정답 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
from math import sqrt, ceil
from bisect import bisect_left, bisect_right
from operator import itemgetter
from collections import Counter, deque
from copy import deepcopy

tc = int(input())
for _ in range(tc):
    n = int(input())
    coins = list(map(int, sys.stdin.readline().split()))
    m = int(input())
    dp = [0] * (m+1)
    dp[0] = 1
    for coin in coins:
        for i in range(m - coin + 1):
            dp[i + coin] += dp[i]
        # print(dp)
    print(dp[m])

'''
3
2
1 2
20
3
1 5 10
20
2
5 7
22
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11]
11
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5]
[1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 9]
9
[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0]
[1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1]
1
'''

Reference

Success Notice: 수고하셨습니다. :+1:

Leave a comment