Hi all,
I was trying to solve http://www.spoj.com/problems/MCHAOS/. I think my solution is correct with the given constraints. I got TLE with BIT and STL maps initially in the 10th test case. Then I tried to optimize by converting strings into longs (preserving their order) and sorting long long[] instead of vector<string>. This gave WA. I believe there is some problem in the hash() function but I was not able to find it. Can someone help or suggest some other optimization?
Here are the two versions of the code...
#include <stdio.h>
#include <string>
#include <map>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long int ll;
const int NMAX = 100000+5;
ll BIT[NMAX];
int LIM = 0;
map<ll, int> dict;
char vs[NMAX][11];
vector<string> vecs, seq;
ll read (int v) {
ll sum = 0;
while (v) {
sum += BIT[v];
v -= v &-v;
}
return sum;
}
void update (int v, int val) {
while (v <= LIM) {
BIT[v] += val;
v += v & -v;
}
}
string reverse(string s) {
char c;
int l = s.length();
for (int i=0; i <= (l-1)/2; i++) {
c = s[i];
s[i] = s[l-i-1];
s[l-i-1] = c;
}
return s;
}
ll hash(string s){
ll sum = 0;
for (int i = 0; i < s.length(); i++) sum = sum * 26 + s[i]-26 + 1;
for (int i = s.length(); i <= 10; i++) sum *= 26;
return sum;
}
int main () {
int N;
memset(BIT, 0, sizeof(BIT));
memset(vs, 0, sizeof(vs));
cin >> N;
string s;
for (int i=0; i < N; i++) scanf("%s", vs[i]);
for (int i=0; i < N; i++) vecs.push_back(vs[i]);
sort(vecs.begin(), vecs.end());
for (int i = 0; i < N; i++) {vecs[i] = reverse(vecs[i]);seq.push_back(vecs[i]);}
sort(seq.begin(), seq.end());
for (int i = 0; i < N; i++) dict[hash(seq[i])] = i+1;
LIM = N+1;
ll bad_pairs = 0;
for (int i = N-1; i >= 0; i--) {
bad_pairs += read(dict[hash(vecs[i])]);
update(dict[hash(vecs[i])], 1);
}
cout << bad_pairs << endl;
return 0;
}
After the optimization ...
#include <stdio.h>
#include <string>
#include <map>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
typedef unsigned long long int ull;
const int NMAX = 100000+5;
ull BIT[NMAX];
int LIM = 0;
map<ull, int> dict,d2;
char vs[NMAX][12], sq[NMAX][12];
ull vecs[NMAX], seq[NMAX];
ull read (int v) {
ull sum = 0;
while (v) {
sum += BIT[v];
v -= (v &-v);
}
return sum;
}
void update (int v, int val) {
while (v <= LIM) {
BIT[v] += val;
v += (v & -v);
}
}
void reverse(char *s) {
char c;
int l = strlen(s);
for (int i=0; i <= l/2-1; i++) {
c = s[i];
s[i] = s[l-i-1];
s[l-i-1] = c;
}
}
ull hash(char *s){
ull sum = 0;
for (int i = 0; i < strlen(s); i++) sum = sum * 26 + s[i]-'a'+1;
for (int i = strlen(s); i < 10; i++) sum *= 26;
return sum;
}
int main () {
int N;
memset(BIT, 0, sizeof(BIT));
memset(vecs, 0, sizeof(vecs));
memset(seq, 0, sizeof(seq));
cin >> N;
string s;
for (int i=0; i < N; i++) scanf("%s", vs[i]);
for (int i=0; i < N; i++) {vecs[i] = hash(vs[i]);d2[vecs[i]] = i;}
sort(vecs, vecs+N);
for (int i = 0; i < N; i++) {
strcpy(sq[i],vs[d2[vecs[i]]]);
reverse(sq[i]);
seq[i] = hash(sq[i]);
}
sort(seq, seq+N);
for (int i = 0; i < N; i++) dict[seq[i]] = i+1;
LIM = N+1;
ull bad_pairs = 0;
for (int i = N-1; i >= 0; i--) {
bad_pairs += read(dict[hash(sq[i])]);
update(dict[hash(sq[i])], 1);
}
cout << bad_pairs << endl;
return 0;
}