본문 바로가기

짜잘한 기록

백준 15481 그래프와 MST

역시 LCA를 활용하는 문제이다. 저번 정점들의 거리 문제에서 MST의 개념을 살짝 붙이기만 하면 방법을 찾을 수 있다.

에지를 입력받을 때, 인접 행렬을 직접 만드는 것이 아니라, 에지만 저장하는 벡터/배열을 선언하고 나중에 그 에지들을 순회하며 MST를 찾을 수 있게 한다.

 

    std::vector<std::pair<int, long long> > adj[N];
    std::priority_queue<edge, std::vector<edge>, cmp> pq;
    std::vector<edge> vEdgeList;

 

일단 에지를 입력받고 난 다음, 그 에지들을 사용해 MST를 계산한다. MST를 계산하면서 adj(인접행렬)을 채워 이 트리를 사용해 LCA를 구하게 된다.

 

for (int i = 0; i < M; i++)
    {
        int s, e;
        long long c;
        std::cin >> s >> e >> c;
        s--;
        e--;
        pq.push(edge(s, e, c));
        vEdgeList.push_back(edge(s, e, c));
        
    }
    memset(pow2parents, -1, sizeof(pow2parents));
    std::fill(depth, depth+N, -1);
    depth[0] = 0;
    for (int i = 0; i < MAX_H; i++)
    {
        for (int j = 0; j < MAX_N; j++)
        {
            distMax[i][j] = 0;
            if (i == 0)
                UFparents[j] = j;
        }
    }

    //MST 구하기
    long long lMSTcost = 0;
    int iedgeCount = 0;

    while (!pq.empty())
    {
        if (iedgeCount == N - 1)
            break;
        edge cur = pq.top();
        pq.pop();
        if (findParents(cur.start) == findParents(cur.end))
            continue;
        else
        {
            unionElement(cur.start, cur.end);
            lMSTcost += cur.cost;
            iedgeCount++;
            adj[cur.start].push_back(std::pair<int, long long> (cur.end, cur.cost));
            adj[cur.end].push_back(std::pair<int, long long> (cur.start, cur.cost));
        }
    }

 

이제 준비가 다 되었다.

특정 에지를 포함하는 MST는, 그 에지를 추가하고 생긴 사이클에서 가장 cost가 높은 것을 없애 만들 수 있다. 여기서 LCA를 사용하게 되는데, 에지가 잇는 두 노드 간 최소 경로를 구한다(MST에서). 그 경로는 해당 에지를 추가하게 되면 사이클이 된다.

따라서, 그 경로에서 가장 cost가 높은 에지를 구해 트리에서 제외시키면 된다.

문제에서는 단순히 해당 MST의 Cost를 구하라고 하였기에 원래 MST 비용 + 추가된 에지 cost - 두 노드 경로 간 최대 cost 를 구하면 된다.

 

for (int i = 0; i < vEdgeList.size(); i++)
    {
        int iv = vEdgeList[i].start;
        int iu = vEdgeList[i].end;
        long long laddCost = vEdgeList[i].cost;
        long long lMaxcost = 0;

        int iDepthDiff;
        if (depth[iv] < depth[iu])
            std::swap(iv, iu);
        iDepthDiff = depth[iv] - depth[iu];

        for (int bit = 0; iDepthDiff; bit++)
        {
            if (iDepthDiff % 2)
            {
                lMaxcost = std::max(lMaxcost, distMax[bit][iv]);
                iv = pow2parents[bit][iv];
            }
            iDepthDiff /= 2;
        }

        if (iu != iv)
        {
            for (int K = MAX_H - 1; K >= 0; K--)
            {
                if (pow2parents[K][iu] != -1 && pow2parents[K][iu] != pow2parents[K][iv])
                {
                    lMaxcost = std::max(lMaxcost, std::max(distMax[K][iu], distMax[K][iv]));
                    iu = pow2parents[K][iu];
                    iv = pow2parents[K][iv];
                }
            }
            lMaxcost = std::max(lMaxcost, std::max(distMax[0][iu], distMax[0][iv]));
        }
        std::cout << lMSTcost + laddCost - lMaxcost << "\n";
    }

 

여기서 쓰인 distMax[K][V] 배열은, V 노드에서 2^K 조상까지 경로 중 최대 cost를 담고 있다.


전체 코드는 다음과 같다.

 

#include <iostream>
#include <vector>
#include <queue>
#include <string.h>
#include <climits>

#define MAX_H 19 // ceil(log2(200000))
#define MAX_N 200000
//bound???
int pow2parents[MAX_H+1][MAX_N+1];
long long distMax[MAX_H+1][MAX_N+1];
int depth[MAX_N+1];
int UFparents[MAX_N+1];

typedef struct edge {
    int start;
    int end;
    long long cost;
    edge(int s, int e, int c) : start(s), end(e), cost(c) {}
} edge;

struct cmp{
    bool operator()(const edge e1, const edge e2)
    {
        return e1.cost > e2.cost;
    }
};

int findParents(int Ielement)
{
    if (UFparents[Ielement] == Ielement)
        return Ielement;
    return UFparents[Ielement] = findParents(UFparents[Ielement]);
}

void unionElement(int Ielement1, int Ielement2)
{
    int Ip1 = findParents(Ielement1);
    int Ip2 = findParents(Ielement2);

    UFparents[Ip1] = Ip2;
}

void initLCA(int icurNode, std::vector<std::pair<int, long long> > adj[])
{
    for (int i = 0; i < adj[icurNode].size(); i++)
    {
        int inext = adj[icurNode][i].first;
        long long inextPathCost = adj[icurNode][i].second;
        if (depth[inext] == -1)
        {
            depth[inext] = depth[icurNode] + 1;
            pow2parents[0][inext] = icurNode;
            distMax[0][inext] = inextPathCost;
            initLCA(inext, adj);
        }
    }
}
int main()
{
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);
    std::cout.tie(0);
    int N, M;
    std::cin >> N >> M;

    std::vector<std::pair<int, long long> > adj[N];
    std::priority_queue<edge, std::vector<edge>, cmp> pq;
    std::vector<edge> vEdgeList;

    for (int i = 0; i < M; i++)
    {
        int s, e;
        long long c;
        std::cin >> s >> e >> c;
        s--;
        e--;
        pq.push(edge(s, e, c));
        vEdgeList.push_back(edge(s, e, c));
        
    }
    memset(pow2parents, -1, sizeof(pow2parents));
    std::fill(depth, depth+N, -1);
    depth[0] = 0;
    for (int i = 0; i < MAX_H; i++)
    {
        for (int j = 0; j < MAX_N; j++)
        {
            distMax[i][j] = 0;
            if (i == 0)
                UFparents[j] = j;
        }
    }

    //MST 구하기
    long long lMSTcost = 0;
    int iedgeCount = 0;

    while (!pq.empty())
    {
        if (iedgeCount == N - 1)
            break;
        edge cur = pq.top();
        pq.pop();
        if (findParents(cur.start) == findParents(cur.end))
            continue;
        else
        {
            unionElement(cur.start, cur.end);
            lMSTcost += cur.cost;
            iedgeCount++;
            adj[cur.start].push_back(std::pair<int, long long> (cur.end, cur.cost));
            adj[cur.end].push_back(std::pair<int, long long> (cur.start, cur.cost));
        }
    }
    

    initLCA(0, adj);
    for (int K = 0; K < MAX_H - 1; K++)
    {
        for (int V = 1; V < N; V++)
        {
            if (pow2parents[K][V] != -1)
            {
                pow2parents[K+1][V] = pow2parents[K][pow2parents[K][V]];
                distMax[K+1][V] = std::max(distMax[K][V], distMax[K][pow2parents[K][V]]);
            }
        }
    }

    for (int i = 0; i < vEdgeList.size(); i++)
    {
        int iv = vEdgeList[i].start;
        int iu = vEdgeList[i].end;
        long long laddCost = vEdgeList[i].cost;
        long long lMaxcost = 0;

        int iDepthDiff;
        if (depth[iv] < depth[iu])
            std::swap(iv, iu);
        iDepthDiff = depth[iv] - depth[iu];

        for (int bit = 0; iDepthDiff; bit++)
        {
            if (iDepthDiff % 2)
            {
                lMaxcost = std::max(lMaxcost, distMax[bit][iv]);
                iv = pow2parents[bit][iv];
            }
            iDepthDiff /= 2;
        }

        if (iu != iv)
        {
            for (int K = MAX_H - 1; K >= 0; K--)
            {
                if (pow2parents[K][iu] != -1 && pow2parents[K][iu] != pow2parents[K][iv])
                {
                    lMaxcost = std::max(lMaxcost, std::max(distMax[K][iu], distMax[K][iv]));
                    iu = pow2parents[K][iu];
                    iv = pow2parents[K][iv];
                }
            }
            lMaxcost = std::max(lMaxcost, std::max(distMax[0][iu], distMax[0][iv]));
        }
        std::cout << lMSTcost + laddCost - lMaxcost << "\n";
    }
}

 

문제에서 한 삽질을 나열하자면,

첫 시도는 그냥 입력받은 에지를 순회하면서 매 에지마다 MST를 구했다.

오 답 잘나와서 넣었는데 당연히 Timeout...

두 번째로 LCA를 붙여서 구현하였는데, 분명히 가지고 있는 테스트케이스는 맞는데 중간에 계속 틀려서 확인해보니, Cost를 계산하는 변수들은 int로 할 경우 overflow가 나더라... 이래서 문제를 잘 읽고 풀어야 되나보다... 해당 변수를 long long 로 바꿔줬다.

세 번째로 잘 돌아가다가 Timeout이 났다. 와 LCA 써서 잘 했는데 왜? 왜 타임아웃인거야? 해서 게시판을 보았다. 이 문제같이 입출력이 많은 문제는

 

std::ios::sync_with_stdio(0);
std::cin.tie(0);
std::cout.tie(0);

 

를 해줘야 입출력 병목이 생기지 않는다고 한다... 이것도 알고는 있었는데 최근에는 입출력 크리티컬한 문제를 본게 아니라 자연스럽게 잊고 있었다.

수정하고 제출하니 PASS...

 

문제를 풀때, 문제를 읽고, 생각한것은 직접 메모하면서 풀어야겠다.

 

오늘도 평온한 하루가 되길. 슨민.

'짜잘한 기록' 카테고리의 다른 글

[나만 몰랐던 알고리즘] 2D 맵 BFS  (0) 2021.09.06
백준 1626 두 번째로 작은 스패닝 트리  (0) 2021.09.06
백준 1761 정점들의 거리  (0) 2021.09.04
백준 3176 도로 네트워크  (0) 2021.09.02
백준 11438 LCA 2  (0) 2021.09.02