IT/ps

백준 15661번 링크와 스타트

u149_cinderella 2025. 2. 8. 00:47

 

 

(링크스타또)

 

 

이번문제와 저번문제의 다른점은 팀의 인원 제한이 없다는 점이다.

재귀함수의 return에 대해서 좀 더 생각해봐야겠다

 

import sys

n=int(sys.stdin.readline().rstrip())
s=[]
for _ in range(n):
    s.append(list(map(int,sys.stdin.readline().rstrip().split())))

v=[0]*n
stack_index=[]
result=[]

def calc():
    sum1=0
    sum2=0
    for i in range(n):
        for j in range(n):
            if v[i]==v[j]:
                if v[i]==1:
                    sum1+=s[i][j]
                else:
                    sum2+=s[i][j]
    return abs(sum1-sum2)

def dfs(c=-1):
    if len(stack_index)==c:
        if len(result)==0:
            result.append(calc())
            return
        result[0]=min(result[0],calc())
        return
    for i in range(1,n//2+1):
        for j in range(n):
            if v[j]==1 or (len(stack_index)!=0 and stack_index[-1]>j):
                continue
            v[j]=1
            stack_index.append(j)
            dfs(i)
            stack_index.pop()
            v[j]=0
dfs()
print(result[0])

짜긴짰는데 시간초과가 나올 것 같은 이 기분

 

여기서 더 줄일수가 있다고?

import sys

n=int(sys.stdin.readline().rstrip())
s=[]
for _ in range(n):
    s.append(list(map(int,sys.stdin.readline().rstrip().split())))

v=[0]*n
stack_index=[]
result=[]

def calc():
    sum1=0
    sum2=0
    for i in range(n):
        for j in range(n):
            if v[i]==v[j]:
                if v[i]==1:
                    sum1+=s[i][j]
                else:
                    sum2+=s[i][j]
    return abs(sum1-sum2)

def dfs(c=-1):
    if len(stack_index)==c:
        if len(result)==0:
            result.append(calc())
            return
        result[0]=min(result[0],calc())
        return
    for i in range(1,n//2+1):
        for j in range(n):
            if len(result)!=0 and result[0]==0:
                return
            if v[j]==1 or (len(stack_index)!=0 and stack_index[-1]>j):
                continue
            v[j]=1
            stack_index.append(j)
            dfs(i)
            stack_index.pop()
            v[j]=0
            
dfs()
print(result[0])

0만나면 즉시 종료하게 작성했다 이걸로 해결인가 싶었지만 여전히 시간초과다

 

calc자체는 400번밖에 실행안돼서 괜찮을거라 생각했는데

기본적인 반복문이 n^2은 돌기때문에 매우 유의미해지는 듯하다

import sys

n=int(sys.stdin.readline().rstrip())
s=[]
for _ in range(n):
    s.append(list(map(int,sys.stdin.readline().rstrip().split())))

v=[0]*n
stack_index=[]
result=[]

def calc():
    sum1=0
    sum2=0
    sum1_set=[]
    sum2_set=[]

    for i in range(n):
        if v[i]==1:
            sum1_set.append(i)
        else:
            sum2_set.append(i)

    for i in range(len(sum1_set)):
        for j in range(len(sum1_set)):
            sum1+=s[sum1_set[i]][sum1_set[j]]
    for i in range(len(sum2_set)):
        for j in range(len(sum2_set)):
            sum2+=s[sum2_set[i]][sum2_set[j]]
    
    return abs(sum1-sum2)

def dfs(c=-1):
    if len(stack_index)==c:
        if len(result)==0:
            result.append(calc())
            return
        result[0]=min(result[0],calc())
        return
    for i in range(1,n//2+1):
        for j in range(n):
            if v[j]==1 or (len(stack_index)!=0 and stack_index[-1]>j):
                continue
            v[j]=1
            stack_index.append(j)
            dfs(i)
            stack_index.pop()
            v[j]=0
            if len(result)!=0 and result[0]==0:
                return
            
dfs()
print(result[0])

 

아주 살짝 유의미하게 빨라졌을것같긴한데 이걸로 통과가 될까싶다

 

아 알았다 재귀가 쓸데없이 많이 되는거였다

 

import sys

n=int(sys.stdin.readline().rstrip())
s=[]
for _ in range(n):
    s.append(list(map(int,sys.stdin.readline().rstrip().split())))

v=[0]*n
stack_index=[]
result=[]

def calc():
    sum1=0
    sum2=0
    sum1_set=[]
    sum2_set=[]

    for i in range(n):
        if v[i]==1:
            sum1_set.append(i)
        else:
            sum2_set.append(i)

    for i in range(len(sum1_set)):
        for j in range(len(sum1_set)):
            sum1+=s[sum1_set[i]][sum1_set[j]]
    for i in range(len(sum2_set)):
        for j in range(len(sum2_set)):
            sum2+=s[sum2_set[i]][sum2_set[j]]
    
    return abs(sum1-sum2)

def dfs(c=-1):
    if len(stack_index)==c:
        if len(result)==0:
            result.append(calc())
            return
        result[0]=min(result[0],calc())
        return
    if c==-1:
        for i in range(1,n//2+1):
            for j in range(n):
                if v[j]==1 or (len(stack_index)!=0 and stack_index[-1]>j):
                    continue
                v[j]=1
                stack_index.append(j)
                dfs(i)
                stack_index.pop()
                v[j]=0
                if len(result)!=0 and result[0]==0:
                    return
    else:
        for j in range(n):
            if v[j]==1 or (len(stack_index)!=0 and stack_index[-1]>j):
                continue
            v[j]=1
            stack_index.append(j)
            dfs(c)
            stack_index.pop()
            v[j]=0
            if len(result)!=0 and result[0]==0:
                return
            
dfs()
print(result[0])

전과는 비교도 안될정도로 빨라졌다

 

gpt행님이 재귀가 너무 많이 된다는걸 지적해줘서 겨우 풀었다.

 

완전 자력솔이 아니라 아쉽다.