[BOJ] 백준 10986번 문제 풀이 (Python)
Updated:
1. 풀이
연속된 부분합 중에서 M으로 나누어 떨어지는 부분합의 갯수를 찾아야한다.
투포인터, DP 등 다양한 방법을 고민해봤지만 결국 부분합 자체만을 이용해서 풀었는데 방법은 매우 단순하다. 반복문을 이용해 원소를 하나씩 접근하면서 i
번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구했다.
예제를 통해 어떤 방법으로 구했는지 살펴보자
input 1 2 3 1 2
pSum 1 3 6 7 9
예제 입력의 부분합을 모두 구해pSum
에 모두 저장하였다. 어차피 M으로 나누어 떨어지는 부분합의 갯수를 모두 찾는 것이므로 각 부분합을 M으로 나눴을 때 나머지가 각각 몇 개 존재하는 지 count에 저장하였다.
이제 0번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구해보자. M으로 나누어 떨어지는 갯수는 count[0]
이므로 3이다.
input 1 2 3 1 2
pSum 1 3 6 7 9
pSum 1 0 0 1 0 (M으로 나눴을 때 나머지)
0 1 2
count 3 2 0
result 3
1번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구해야 하는데 주의할 점 2개 있다.
pSum[0]
(=1, 나머지 1)이 제외되므로count[1]
이 1 줄어든다.pSum[0]
이 제외되므로 전체적으로 각 부분합이 1만큼 줄어들게된다. 따라서 기존count[0]
은 나머지가 2인 것을 가리키게 되고count[1]
은 나머지가 0인 것을 가리키게 된다.count[2]
는 나머지가 1인 것을 가리키게 된다. (왼쪽으로 1만큼 이동)
input 2 3 1 2
pSum 3 6 7 9
pSum 0 0 1 0 (M으로 나눴을 때 나머지)
0(→2) 1(→0) 2(→1)
count 3 1 0
result 4(=3+1)
2번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.
pSum[1]
(=3, 나머지 0)이 제외되므로count[0]
이 1 줄어든다.pSum[1]
이 제외되므로 전체적으로 각 부분합이 2(=나머지 2)만큼 줄어들게된다. 따라서 다시count[0]
이 나머지가 0인 것을 가리키게 된다. (왼쪽으로 2만큼 이동)
input 1 2 3 1 2
pSum 6 7 9
pSum 0 1 0 (M으로 나눴을 때 나머지)
0(→0) 1(→1) 2(→2)
count 2 1 0
result 6(=3+1+2)
3번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.
pSum[2]
(=6, 나머지 0)이 제외되므로count[0]
이 1 줄어든다.pSum[2]
이 제외되므로 전체적으로 각 부분합이 3(=나머지 0)만큼 줄어들게된다. 기존 그대로count[0]
이 나머지가 0인 것을 가리키게 된다. (왼쪽으로 0만큼 이동)
input 1 2 3 1 2
pSum 7 9
pSum 1 0 (M으로 나눴을 때 나머지)
0(→0) 1(→1) 2(→2)
count 1 1 0
result 7(=3+1+2+1)
4번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.
pSum[3]
(=7, 나머지 1)이 제외되므로count[1]
이 1 줄어든다.pSum[3]
이 제외되므로 전체적으로 각 부분합이 1(=나머지 1)만큼 줄어들게된다. 따라서 기존count[0]
은 나머지가 2인 것을 가리키게 되고count[1]
은 나머지가 0인 것을 가리키게 된다.count[2]
는 나머지가 1인 것을 가리키게 된다. (왼쪽으로 1만큼 이동)
input 1 2 3 1 2
pSum 9
pSum 0 (M으로 나눴을 때 나머지)
0(→2) 1(→0) 2(→1)
count 1 0 0
result 7(=3+1+2+1+0)
최종적으로 코드를 통해 위 과정을 간략히 표현한다면 다음과 같아진다.
result = 0
# IndexError를 방지하기 위해서 시작을 1로 변경 (범위: 1 ~ n+1)
for i in range(1,n+1):
result += count[pSum[i-1]]
count[pSum[i]] -= 1
2. 코드
import sys
def input(): return sys.stdin.readline().rstrip()
n,m = map(int,input().split())
nums = list(map(int,input().split()))
pSum = [0] * (n+1)
# IndexError를 방지하기 위해서 시작을 1로 변경 (범위: 1 ~ n+1)
for i in range(1,n+1):
pSum[i] = pSum[i-1] + nums[i-1]
nums = list(map(lambda x: x % m,nums))
pSum = list(map(lambda x: x % m, pSum))
count = [0] * m
for p_sum in pSum[1:]:
count[p_sum] += 1
result = 0
for i in range(1,n+1):
result += count[pSum[i-1]]
count[pSum[i]] -= 1
print(result)
Leave a comment