[UCPC 2023 예선 C번]
[Gold I]
https://www.acmicpc.net/problem/28297
풀이
수학적 지식과, MST(Minimum Spanning Tree)를 이용해 풀 수 있었던 문제.
대회 당일에는, 시간을 얼마 안 남겨놓고 풀이를 시작해서, 조급하게 Union-Find로 접근하다 O(N^3)으로 TLE를 냈는데, 끝나고 다시 보니 MST로 풀이하면 된다는 것을 깨달았다...!
우선, 두 점 사이의 유클리드 거리는 아래와 같이 구할 수 있다. (line 46)
두 기어가 어떻게든 맞닿아 있다면, 그 기어들은 연결되어 같이 돌아가게 되므로, union해 준다.
=> (두 기어의 중심 사이의 거리) > (두 기어의 반지름의 합)
일 때, union. (line 89~97)
기어가 맞닿아 있지 않다면 (union되어 있지 않다면), 벨트를 연결해야 한다.
모든 기어가 연결된다면 (모든 기어가 union되어 있다면), 문제에서 요구하는 바를 만족하게 된다.
따라서, 모든 기어를 정점(vertex)로 본다면, 그 정점들을 모두 연결하면 되며, 그 과정에서 필요한 벨트의 길이가 최소가 되도록 해야 하므로, MST이다.
연결되지 않은 기어(정점)와 기어를 연결했을 때의 필요한 벨트의 길이를 간선(edge)로 간주하고, 모든 정점 사이의 간선 길이를 구한 후, Kruskal 알고리즘을 통해 MST를 만들면 문제가 해결된다.
이 때, 총 시간 복잡도는 모든 간선을 구하는 시간에 좌우되므로, O(N^2)이다.
그럼 남은 것은, 기어를 연결할 때 필요한 벨트의 길이를 구하는 방법이다. (line 49~73)
위와 같은 그림에서, 우리가 구해야 할 벨트의 길이는 아래와 같다.
우선 h는 피타고라스의 법칙으로 구할 수 있다. (line 68)
l1 과 l2 는 원주를 구하는, 아래 공식으로 구할 수 있다.
따라서, 아래와 같다.
이 때, θ를 구하기 위해서는, 삼각함수를 이용해야 하는데,
이므로,
이다. (line 69)
아래 코드에서는, θ를 다르게 잡아 cos이 아닌, sin을 이용해 계산했으나, 답은 동일하게 계산된다.
이를 이용해 두 기어를 연결할 때 필요한 벨트의 길이를 계산하여 모든 간선의 길이를 구하고, 이를 이용해 Kruskal로 MST를 만들면 끝!
#include <math.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <stack>
#include <tuple>
#include <vector>
#define init cin.tie(0)->ios_base::sync_with_stdio(0);
using namespace std;
typedef long long ll;
typedef pair<int, int> pint;
typedef tuple<int, int, int> iii;
typedef pair<ll, ll> pll;
int parent[1004];
int N;
vector<iii> cir;
int find(int x) {
if (parent[x] == x)
return x;
else
return parent[x] = find(parent[x]);
}
bool _union(int x, int y) {
x = find(x);
y = find(y);
if (x == y) return false;
parent[min(x, y)] = max(x, y);
return true;
}
double distance(int x1, int y1, int x2, int y2) {
return sqrt(pow(y2 - y1, 2) + pow(x2 - x1, 2));
}
double beltDist(iii cir1, iii cir2) {
int x1, y1, r1, x2, y2, r2;
if (get<2>(cir1) > get<2>(cir2)) {
x1 = get<0>(cir1);
y1 = get<1>(cir1);
r1 = get<2>(cir1);
x2 = get<0>(cir2);
y2 = get<1>(cir2);
r2 = get<2>(cir2);
} else {
x1 = get<0>(cir2);
y1 = get<1>(cir2);
r1 = get<2>(cir2);
x2 = get<0>(cir1);
y2 = get<1>(cir1);
r2 = get<2>(cir1);
}
double d = distance(x1, y1, x2, y2);
double x = sqrt(pow(d, 2) - pow(r1 - r2, 2));
double theta = asin((r1 - r2) / d);
double res = 2 * x + r2 * (M_PI - 2 * theta) + r1 * (M_PI + 2 * theta);
return res;
}
int main() {
init;
for (int i = 0; i < 1004; i++) {
parent[i] = i;
}
cin >> N;
int x, y, r;
for (int i = 0; i < N; i++) {
cin >> x >> y >> r;
cir.push_back({x, y, r});
}
// 거리 측정 후 이미 닿아 있는 원은 union.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
if (i == j) continue;
double dis = distance(get<0>(cir[i]), get<1>(cir[i]), get<0>(cir[j]), get<1>(cir[j]));
if (dis <= get<2>(cir[i]) + get<2>(cir[j])) {
_union(i, j);
}
}
}
double v[N][N];
// 모든 점 사이의 거리를 측정 (이미 연결된 원 제외))
vector<tuple<double, int, int>> edges;
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
if (i == j || find(i) == find(j)) continue;
edges.push_back({beltDist(cir[i], cir[j]), i, j});
}
}
// 거리순 오름차순 정렬
sort(edges.begin(), edges.end());
double res = 0;
// Kruskal(MST)
for (int i = 0; i < edges.size(); i++) {
double dist = get<0>(edges[i]);
int v1 = get<1>(edges[i]);
int v2 = get<2>(edges[i]);
if (find(v1) != find(v2)) {
res += dist;
_union(v1, v2);
}
}
cout.precision(20);
cout << res;
}