Hi, I need help understanding the solution to this problem: https://leetcode.com/problems/find-the-minimum-cost-array-permutation/
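Basically, we are given a permutation $$$nums$$$ of $$$0, 1, \ldots, n-1$$$ and have to construct another permutation $$$res$$$ of the same numbers that minimizes the score $$$\sum_{i=0}^{n-1} |res[i] - nums[res[(i+1) \bmod n]]|$$$. If several permutations reach the minimum, the lexicographically smallest one is required.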
Can anyone explain the general idea of the solution? I see a union-find (DSU) data structure being used, and I assume the following:
We start the permutation with 0 because of the cyclic property: any optimal answer can be cyclically rotated so that 0 comes first, and the rotation does not change the score.
ans[i] is the element that follows i in the permutation.
We look at the relation between nums[ans[i]] and i, and try to "merge" elements to minimize the absolute difference. However, I don't understand the check code.
- What are the dir and need arrays for, and why do we check !merge(i, i+1)?
- Why does the fact that two adjacent nodes are already connected (through the nums array) disqualify a candidate for the next value (i.e. ans[i])?
- Why are those particular ranges chosen when calling work() for x < i and for x > i?
What is the purpose of merging blocks, and why does it matter?
// LeetCode 3149 "Find the Minimum Cost Array Permutation": return the lexicographically
// smallest permutation `res` of 0..n-1 that minimizes
//   sum_i |res[i] - nums[res[(i + 1) % n]]|.
//
// The answer is built as a successor map: ans[v] is the element that follows v.
// Starting from 0, each element greedily receives the smallest admissible successor
// that check() accepts.
class Solution {
public:
    // Disjoint-set union over 0..n-1. leader() uses path halving. merge() always hangs
    // y's root under x's root; siz is maintained but is not used to pick the root.
    struct DSU {
        std::vector<int> f, siz;
        DSU(int n) : f(n), siz(n, 1) { std::iota(f.begin(), f.end(), 0); }
        int leader(int x) {
            while (x != f[x]) x = f[x] = f[f[x]];
            return x;
        }
        bool same(int x, int y) { return leader(x) == leader(y); }
        // Returns false, and changes nothing, when x and y are already in one set.
        bool merge(int x, int y) {
            x = leader(x);
            y = leader(y);
            if (x == y) return false;
            siz[x] += siz[y];
            f[y] = x;
            return true;
        }
    };

    // nums: a permutation of 0..n-1 (not modified).
    // Returns the chosen permutation, which always starts with 0.
    std::vector<int> findPermutation(std::vector<int>& nums) {
        int n = nums.size();
        std::vector<int> ans(n, -1), res(n);
        // Without this guard, n <= 1 breaks the code below:
        // - n == 1: check() builds dir(n - 2), a negative size that converts to a huge
        //   size_t and throws.
        // - n == 0: the main loop writes res[0] of an empty vector.
        // res already holds the only possible answer: {} for n == 0, {0} for n == 1.
        if (n <= 1) {
            return res;
        }

        // Put every cycle of nums (j -> nums[j] -> ...) into a single DSU set.
        DSU dsu(n);
        std::vector<bool> vis(n);
        for (int i = 0; i < n; i++) {
            if (vis[i]) {
                continue;
            }
            for (int j = i; !vis[j]; j = nums[j]) {
                vis[j] = true;
                dsu.merge(j, i);
            }
        }

        // Feasibility test for the current partial successor map `ans`
        // (-1 means unassigned). Edge j is the link between values j and j + 1.
        // Works on a copy `g` of dsu, so dsu itself is never changed.
        // NOTE(review): the reading "can `ans` still be completed to a minimum-score
        // permutation" is inferred from how the driver loop uses this; confirm.
        auto check = [&]() {
            DSU g = dsu;
            // need[j]: edge j lies between some assigned i and x = nums[ans[i]],
            // i.e. j is in [min(i, x), max(i, x)).
            std::vector<bool> need(n - 1);
            for (int i = 0; i < n; i++) {
                if (ans[i] != -1) {
                    int x = nums[ans[i]];
                    if (x < i) {
                        for (int j = x; j < i; j++) {
                            need[j] = true;
                        }
                    } else {
                        for (int j = i; j < x; j++) {
                            need[j] = true;
                        }
                    }
                }
            }
            // Every needed edge must join two different sets. A needed edge whose
            // endpoints are already connected rejects the assignment.
            for (int i = 0; i < n - 1; i++) {
                if (need[i] && !g.merge(i, i + 1)) {
                    return false;
                }
            }
            // dir[j] describes the junction between edges j and j + 1:
            //   -1    = unconstrained
            //   1 / 2 = a single requested orientation
            //   0     = both orientations were requested (conflict)
            // cant[j]: edge j may not be used (set next to a value with x == i).
            std::vector<int> dir(n - 2, -1);
            std::vector<bool> cant(n - 1);
            auto work = [&](int x, int d) {
                if (dir[x] >= 0 && dir[x] != d) {
                    dir[x] = 0;
                }
                else {
                    dir[x] = d;
                }
            };
            for (int i = 0; i < n; i++) {
                if (ans[i] != -1) {
                    int x = nums[ans[i]];
                    if (x < i) {
                        // Span [x, i): interior junctions x..i-2 get orientation 1.
                        // The outer junctions on either side (if those edges exist)
                        // get orientation 2.
                        for (int j = x; j + 1 < i; j++) {
                            work(j, 1);
                        }
                        if (x > 0) {
                            work(x - 1, 2);
                        }
                        if (i + 1 < n) {
                            work(i - 1, 2);
                        }
                    } else if (x > i) {
                        // Mirror image for span [i, x): interior junctions i..x-2 get
                        // orientation 2, and the outer junctions get orientation 1.
                        for (int j = x-1; j > i ; j--) {
                            work(j-1 , 2);
                        }
                        if (x + 1 < n) {
                            work(x - 1, 1);
                        }
                        if (i > 0) {
                            work(i - 1, 1);
                        }
                    } else {
                        // x == i: neither edge touching x may be used.
                        if (x + 1 < n) {
                            cant[x] = true;
                        }
                        if (x > 0) {
                            cant[x - 1] = true;
                        }
                    }
                }
            }
            // Two needed edges meeting at a conflicting junction: infeasible.
            for (int i = 0; i + 2 < n; i++) {
                if (need[i] && need[i + 1] && dir[i] == 0) {
                    return false;
                }
            }
            // A needed edge that is also forbidden: infeasible.
            for (int i = 0; i + 1 < n; i++) {
                if (need[i] && cant[i]) {
                    return false;
                }
            }
            // Optimistically add every remaining edge that is:
            // - not needed (needed edges are already merged),
            // - not forbidden,
            // - not next to a conflicting junction.
            for (int i = 0; i < n - 1; i++) {
                if (need[i]) {
                    continue;
                }
                if (cant[i]) {
                    continue;
                }
                if (i > 0 && dir[i - 1] == 0) {
                    continue;
                }
                if (i < n - 2 && dir[i] == 0) {
                    continue;
                }
                g.merge(i, i + 1);
            }
            // Accept only if every value ends up connected to 0.
            for (int i = 0; i < n; i++) {
                if (!g.same(0, i)) {
                    return false;
                }
            }
            return true;
        };

        // Walk the cycle starting at 0. For the current element i:
        // - try successors 0, 1, 2, ... in increasing order;
        // - skip already-placed elements while some remain unplaced;
        // - keep the first candidate that check() accepts.
        // The walk ends when the accepted successor is 0. Because 0 is placed first,
        // that can only happen once all n elements are placed.
        // NOTE(review): if check() rejected every candidate, ans[i] would run past n - 1
        // and cyc[ans[i]] / nums[ans[i]] would read out of range. The loop relies on
        // check() always accepting some candidate.
        std::vector<bool> cyc(n);
        int cnt = 0;
        int j = 0;
        for (int i = 0; ; i = ans[i]) {
            res[j] = i;
            j++;
            cyc[i] = true;
            cnt++;
            ans[i] = 0;
            while ((cnt < n && cyc[ans[i]]) || !check()) {
                ans[i]++;
            }
            if(ans[i] == 0) break;
        }
        return res;
    }
};
Auto comment: topic has been updated by SpongeCodes (previous revision, new revision, compare).