Updated:

1. 풀이

연속된 부분합 중에서 M으로 나누어 떨어지는 부분합의 갯수를 찾아야한다.

투포인터, DP 등 다양한 방법을 고민해봤지만 결국 부분합 자체만을 이용해서 풀었는데 방법은 매우 단순하다. 반복문을 이용해 원소를 하나씩 접근하면서 i번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구했다.

예제를 통해 어떤 방법으로 구했는지 살펴보자

input 1 2 3 1 2
pSum  1 3 6 7 9

예제 입력의 부분합을 모두 구해pSum에 모두 저장하였다. 어차피 M으로 나누어 떨어지는 부분합의 갯수를 모두 찾는 것이므로 각 부분합을 M으로 나눴을 때 나머지가 각각 몇 개 존재하는 지 count에 저장하였다.

이제 0번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구해보자. M으로 나누어 떨어지는 갯수는 count[0]이므로 3이다.

input 1 2 3 1 2
pSum  1 3 6 7 9
pSum  1 0 0 1 0 (M으로 나눴을 때 나머지)

      0 1 2
count 3 2 0

result 3

1번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수를 구해야 하는데 주의할 점 2개 있다.

  • pSum[0](=1, 나머지 1)이 제외되므로 count[1]이 1 줄어든다.
  • pSum[0]이 제외되므로 전체적으로 각 부분합이 1만큼 줄어들게된다. 따라서 기존 count[0]은 나머지가 2인 것을 가리키게 되고 count[1]은 나머지가 0인 것을 가리키게 된다. count[2]는 나머지가 1인 것을 가리키게 된다. (왼쪽으로 1만큼 이동)
input   2 3 1 2
pSum    3 6 7 9
pSum    0 0 1 0 (M으로 나눴을 때 나머지)

      0(→2) 1(→0) 2(→1)
count 3     1     0

result 4(=3+1)

2번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.

  • pSum[1](=3, 나머지 0)이 제외되므로 count[0]이 1 줄어든다.
  • pSum[1]이 제외되므로 전체적으로 각 부분합이 2(=나머지 2)만큼 줄어들게된다. 따라서 다시 count[0]이 나머지가 0인 것을 가리키게 된다. (왼쪽으로 2만큼 이동)
input 1 2 3 1 2
pSum      6 7 9
pSum      0 1 0 (M으로 나눴을 때 나머지)

      0(→0) 1(→1) 2(→2)
count 2     1     0

result 6(=3+1+2)

3번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.

  • pSum[2](=6, 나머지 0)이 제외되므로 count[0]이 1 줄어든다.
  • pSum[2]이 제외되므로 전체적으로 각 부분합이 3(=나머지 0)만큼 줄어들게된다. 기존 그대로 count[0]이 나머지가 0인 것을 가리키게 된다. (왼쪽으로 0만큼 이동)
input 1 2 3 1 2
pSum        7 9
pSum        1 0 (M으로 나눴을 때 나머지)

      0(→0) 1(→1) 2(→2)
count 1     1     0

result 7(=3+1+2+1)

4번째 원소부터 시작하는 부분합 중에서 M으로 나누어 떨어지는 갯수도 구해보자.

  • pSum[3](=7, 나머지 1)이 제외되므로 count[1]이 1 줄어든다.
  • pSum[3]이 제외되므로 전체적으로 각 부분합이 1(=나머지 1)만큼 줄어들게된다. 따라서 기존 count[0]은 나머지가 2인 것을 가리키게 되고 count[1]은 나머지가 0인 것을 가리키게 된다. count[2]는 나머지가 1인 것을 가리키게 된다. (왼쪽으로 1만큼 이동)
input 1 2 3 1 2
pSum          9
pSum          0 (M으로 나눴을 때 나머지)

      0(→2) 1(→0) 2(→1)
count 1     0     0

result 7(=3+1+2+1+0)

최종적으로 코드를 통해 위 과정을 간략히 표현한다면 다음과 같아진다.

result = 0
# IndexError를 방지하기 위해서 시작을 1로 변경 (범위: 1 ~ n+1)
for i in range(1,n+1):
    result += count[pSum[i-1]]
    count[pSum[i]] -= 1

2. 코드

import sys
def input(): return sys.stdin.readline().rstrip()

n,m = map(int,input().split())
nums = list(map(int,input().split()))
pSum = [0] * (n+1)

# IndexError를 방지하기 위해서 시작을 1로 변경 (범위: 1 ~ n+1)
for i in range(1,n+1):
    pSum[i] = pSum[i-1] + nums[i-1]
nums = list(map(lambda x: x % m,nums))
pSum = list(map(lambda x: x % m, pSum))
count = [0] * m
for p_sum in pSum[1:]:
    count[p_sum] += 1
    
result = 0
for i in range(1,n+1):
    result += count[pSum[i-1]]
    count[pSum[i]] -= 1
print(result)

Leave a comment