Interpolation of a multivariable polynomial, or how to guess a formula
Разница между en1 и en2, 72 символ(ов) изменены
Hello, Codeforces. I wrote a prorgam using [this theorem](https://en.wikipedia.org/wiki/Combinatorial_Nullstellensatz?wprov=srpw1_1) for interpolation a function as a polynomial. At first, I will explain how does it work and then show some examples of using it.↵

Wikipedia does not provide one formula I need, so I will show it↵

<spoiler summary="Formula">↵
If there are now monomials with all exponents more or equal than our exponents, the coefficient with our monomial can be calculated by formula↵
![ ](/predownloaded/03/5e/035e5f0089c741b747d1816e0c773d0eaa138529.png)↵
</spoiler>↵

Code is quite long, but simple in use.↵

How to use program↵
------------------↵

At first, you need to set values of constants. N is the number of variables, MAX_DEG is the maximal exponent of a variable over all monomials. In `main` you need to fill 2 arrays: names contains the names of all variables, max_exp[i] is the maximal exponent of the i-th variable, or the upper_bound of its value. ↵

Define `d = (max_exp[0] + 1) * (max_exp[1] + 1) * ... * (max_exp[N - 1] + 1)`. MAX_PRODUCT should be greater then d. Then you need to write a function f(array<ll, N>), which returns ll or ld. In my code it returns an integer, but its type is ld to avoid ll overflow.↵

<spoiler summary="Code">↵
~~~~~↵
#include <bits/stdc++.h>↵
using namespace std;↵
#define ll long long↵
#define ld long double↵
#define fi first↵
#define se second↵
#define pb push_back↵
#define cok cout << (ok ? "YES\n" : "NO\n");↵
#define dbg(x) cout << (#x) << ": " << (x) << endl;↵
#define dbga(x,l,r) cout << (#x) << ": "; for (int ii=l;ii<r;ii++) cout << x[ii] << " "; cout << endl;↵
// #define int long long↵
#define pi pair<int, int>↵
const int N = 7, C = 1e7, MAX_DEG = 4, MAX_PRODUCT = 1e5;↵
const ld EPS = 1e-9, EPS_CHECK = 1e-9;↵
const string SEP = "  (", END = ")\n";↵
const bool APPROXIMATION = true;↵
array <string, N> names;↵
array <int, N> max_exp, powers, current_converted, cur_exp;↵
array<vector<ll>, N> POINTS;↵
ll DIV[N][MAX_DEG + 1][MAX_DEG + 1], PW[N][MAX_DEG + 1][MAX_DEG + 1];↵
ld SUM[MAX_PRODUCT];↵
ld F_CACHE[MAX_PRODUCT];↵
ll pow(ll a, int b)↵
{↵
if (b == 0) return 1;↵
if (b == 1) return a;↵
ll s = pow(a, b / 2);↵
s *= s;↵
if (b & 1) s *= a;↵
return s;↵
}↵
ld approximate(ld k)↵
{↵
int k_ = k;↵
int k__ = k_ + abs(k) / k;↵
if (abs(k - k_) < EPS) return k_;↵
else if (abs(k - k__) < EPS) return k__;↵
else↵
{↵
int i = 1, j = 1;↵
ld ka = abs(k);↵
while (i < C && j < C)↵
{↵
ld p = ka * j;↵
if (abs(p - i) < EPS) break;↵
if (p < i) j++;↵
else i++;↵
}↵
if (i >= C || j >= C) return k;↵
if (k < 0) i = -i;↵
return (ld)i / j;↵
} ↵
}↵
void normalize(ld k)↵
{↵
    if (!APPROXIMATION)↵
    {↵
        cout << k << SEP;↵
        return;↵
    }↵
int k_ = k;↵
int k__ = k_ + abs(k) / k;↵
if (abs(k - k_) < EPS) cout << k_ << SEP;↵
else if (abs(k - k__) < EPS) cout << k__ << SEP;↵
else↵
{↵
int i = 1, j = 1;↵
ld ka = abs(k);↵
while (i < C && j < C)↵
{↵
ld p = ka * j;↵
if (abs(p - i) < EPS) break;↵
if (p < i) j++;↵
else i++;↵
}↵
if (i >= C || j >= C)↵
{↵
cout << k << SEP;↵
return;↵
}↵
if (k < 0) i = -i;↵
cout << i << "/" << j << SEP;↵
}↵
}↵
struct monom↵
{↵
array<int, N> exp;↵
ld k;↵
int deg;↵
monom(array<int, N> v, ld k_)↵
{↵
k = k_;↵
exp = v;↵
deg = 0;↵
for (int i=0;i<N;i++) deg += exp[i];↵
}↵
void display()↵
{↵
normalize(k);↵
if (deg == 0) { cout << "1" << END; return;}↵
bool go = 0;↵
for (int i=0;i<N;i++)↵
{↵
if (go && exp[i]) cout << " * ";↵
if (exp[i]) go = 1, cout << names[i] + "^" + to_string(exp[i]);↵
}↵
cout << END;↵
}↵
ld operator()(array<int, N> v)↵
{↵
ll res = 1;↵
for (int i=0;i<N;i++) res *= PW[i][v[i]][exp[i]];↵
return k * res;↵
}↵
ld getRandom(array<ll, N> v)↵
{↵
ld res = 1;↵
for (int i=0;i<N;i++) res *= pow(v[i], exp[i]);↵
return k * res;↵
}↵
};↵
bool operator<(monom a, monom b)↵
{↵
if (a.deg > b.deg) return 1;↵
if (a.deg < b.deg) return 0;↵
if (a.exp > b.exp) return 1;↵
if (a.exp < b.exp) return 0;↵
return a.k > b.k;↵
}↵
struct polynom↵
{↵
vector<monom> st;↵
void add(monom m)↵
{↵
if (abs(m.k) < EPS) return;↵
st.pb(m);↵
}↵
void print() { if(st.size() == 0) {cout << "Polynom is 0\n"; return;} sort(st.begin(), st.end()); for (monom &m: st) m.display();}↵
ld operator()(array<ll, N> v)↵
{↵
ld res = 0;↵
for (auto &m: st) res += m.getRandom(v);↵
return res;↵
}↵
};↵
ld gen(int index=0, int current_hash=0)↵
{↵
if (index == N)↵
{↵
ll div = 1;↵
for (int i=0;i<N;i++) div *= DIV[i][current_converted[i]][cur_exp[i]];↵
return (ld)(F_CACHE[current_hash] - SUM[current_hash]) / div;↵
}↵
ld res = 0;↵
for (int i=0;i<=cur_exp[index];i++)↵
{↵
current_converted[index] = i;↵
res += gen(index + 1, current_hash + i * powers[index]);↵
}↵
return res;↵
}↵
array<int, N> convert(int h)↵
{↵
array<int, N> res;↵
for (int i=0;i<N;i++) res[i] = h / powers[i], h -= res[i] * powers[i];↵
return res;↵
}↵
array<ll, N> convert_points(int h)↵
{↵
array<ll, N> res;↵
for (int i=0;i<N;i++) res[i] = POINTS[i][h / powers[i]], h %= powers[i];↵
return res;↵
}↵
polynom interpolate(ld f(array<ll, N>))↵
{↵
    int max_pow = -2e9, sum = 0, h_max = 0;↵
    set<int> remaining_points, st;↵
polynom res;↵
    for (int x: max_exp) max_pow = max(max_pow, x), sum += x, h_max = h_max * (x + 1) + x;↵

    powers[N - 1] = 1;↵
    for (int i=N-2;i>-1;i--) powers[i] = powers[i + 1] * (max_exp[i + 1] + 1);↵

    for (int i=0;i<max_exp.size();i++) for (int j=0;j<=max_exp[i];j++) POINTS[i].pb(j);↵

    for (int i=0;i<N;i++) for (int j=0;j<=max_exp[i];j++) for (int u=0;u<=max_exp[i];u++) DIV[i][j][u] = (u ? DIV[i][j][u - 1] : 1) * (u == j ? 1 : (POINTS[i][j] - POINTS[i][u]));↵

    for (int i=0;i<N;i++) for (int j=0;j<=max_exp[i];j++) for (int u=0;u<=max_pow;u++) PW[i][j][u] = u ? PW[i][j][u - 1] * POINTS[i][j] : 1;↵

    for (int i=0;i<=h_max;i++) F_CACHE[i] = f(convert_points(i)), remaining_points.insert(i);↵
    st.insert(h_max);↵

    while (st.size())↵
{↵
int v = *st.rbegin();↵
st.erase(v);↵
remaining_points.erase(v);↵
cur_exp = convert(v);↵
ld k = gen();↵
if (abs(k) > EPS)↵
{↵
monom mn = monom(cur_exp, k);↵
if (APPROXIMATION) k = approximate(k);↵
monom mn = monom(cur_exp, k);↵
res.add(mn);↵
for (int i: remaining_points) SUM[i] += mn(convert(i));↵
}↵
for (int i=0;i<N;i++) if (cur_exp[i]) st.insert(v - powers[i]);↵
}↵
return res;↵
}↵
ld f(array<ll, N> v)↵
{↵
auto [a, b, c, d, e, f, g] = v;↵
ld res = 0;↵
for (int i=0;i<a;i++)↵
for (int j=0;j<b;j++)↵
for (int u=0;u<c;u++)↵
for (int x=0;x<d;x++)↵
for (int y=0;y<e;y++)↵
for (int z=0;z<f;z++)↵
for (int k=0;k<g;k++)↵
res += 13ll * i * j * u * i * i * u - 49ll * k * k * z * z * y + 90ll * c * u * k * x * x * x;↵
return res;↵
}↵
void check(polynom p, ld(array<ll, N> f))↵
{↵
mt19937 rnd(228);↵
for (int i=0;i<10000;i++)↵
{↵
int t = clock();↵
array<ll, N> ex;↵
for (int j=0;j<N;j++) ex[j] = rnd() % 20 + 2;↵
ld F = f(ex);↵
ld P = p(ex);↵
if (abs(F - P) > max(EPS_CHECK, EPS_CHECK * abs(F)))↵
{↵
cout << "Polynom is wrong, test " << i << endl;↵
cout << F << endl << P << endl;↵
for (int x: ex) cout << x << " ";↵
cout << endl;↵
return;↵
}↵
cout << "Test " << i << " has been passed, time = " << (ld)(clock() - t) / CLOCKS_PER_SEC << "s" << endl;↵
}↵
cout << "Polynom is OK" << endl;↵
}↵
signed main()↵
{↵
    cin.tie(0); ios_base::sync_with_stdio(0);↵
    cout << setprecision(20) << fixed;↵

    names = {"a", "b", "c", "d", "e", "f", "g"};↵
    max_exp = {4, 2, 3, 4, 2, 3, 3};↵
    ↵
    polynom P = interpolate(f);↵
    P.print();↵
    //cout << "Checking polynom..." << endl;↵
    //check(P, f);↵
}↵
~~~~~↵
</spoiler>↵

#### Stress-testing↵
If you uncomment 2 last rows in main, program will check the polynom it got on random tests. The test generation should depend from `f` working time, because it can run too long on big tests.↵

#### Approximations↵
Function from exaple and similar functions (with n loops) are a polynomial with rational coefficients (if this is not true, the function does not return an integer). So, if APPROXIMATION = true, all coefficients are approximating to rational numbers with absolute error < EPS with functions `normalize` and `approximate` (they are using the same algorithm).↵
This algorithm works in O(numerator + denominator), that seems to be slow, but if the polynomial has a small amount of monomials, it does not take much time.↵

Stress-testing function `check` considers a value correct if its absolute or relative error < EPS_CHECK.  ↵

### How and how long does it work↵

We consider monomials as arrays of exponents. We hash these arrays. Array PW contains powers of points (from POINTS), which we use for interpolation. If you want to use your points for interpolation, modify POINTS. If you use fractional numbers, replace `#define ll long long` with `#define ll long double`.↵
Array DIV is used for fast calculating denominators in the formula. ↵

`convert(h)` &mdash; get indexes (in array POINTS) of coordinates of the point corresponding to the monomial with hash h.↵
`convert_points(h)` &mdash; get coordinates of the point corresponding to the monomial with hash h.↵

Then we are precalcing values of `f` in all our points and write them to `F_CACHE`. After it, we run bfs on monomials. During the transition from one monomial to another we decrease the exponent of one variable by 1.↵
When a monomial is got from set in bfs, we find its coefficient using `gen`. If it is not zero, we need to modify our polynomial for every monomials we has not considered in bfs yet ("monomial" and "point" have the same concepts because we can get a point from monomial using convert_points(h), if h is a hash of the monomial).↵

We need to modify the polynomial to make one of the theorem's conditions satisfied: there are no monomials greater than our monomial (greater means that all exponents are more or equal). For every point we has not consider in bfs (they will be in set `remaining_points`) we add the value of the monomial in this point to SUM[hash_of_point]. Then we will decrease `f(point)` by `SUM[hash_of_point]` to artificially remove greater monomials.↵

#### Time complexity↵
1. The longest part of precalc &mdash; calculating F_CACHE &mdash; take O(d * O(f)) time↵
2. Each of d runs of `gen` is iterating over O(d) points, denominator calculation takes O(N) time for each point.↵
3. For every monomial with non-zero coefficient we calculate its values in O(d) points in O(N) for each point.↵

We have got `O(d * O(f) + d^2 * N + d * O(res))`, where `O(res)` is the time it takes to calculate the polynomial we got as a result.↵

### Trying to optimize↵

It seems that the recursion takes the most time. We can unroll it in one cycle using stack. It is boring, so I decided to try to unroll it in other way. For every monomial with non-zero coefficient, let`s iterate over all monomials with hash less or equal to our hash. For every monomial we check if it is less than our monomial (all corresponding exponents are less or equal). If it is lower, we add to the coefficient the value of fraction in this point (monomial).↵


~~~~~↵
// Instead of ld k = gen();↵
ld k = 0;↵
for (int h=0;h<=v;h++)↵
{↵
    array<int, N> cur = convert(h);↵
    bool ok = 1;↵
    for (int i=0;i<N;i++) if (cur[i] > cur_exp[i]) ok = 0;↵
    if (ok)↵
    {↵
ll div = 1;↵
        for (int i=0;i<N;i++) div *= DIV[i][cur[i]][cur_exp[i]];↵
        k += (ld)(F_CACHE[h] - SUM[h]) / div;↵
    }↵
}↵
~~~~~↵

Is it faster than `gen`? New implementation is iterating over all pairs of hashes, so it works in `O(d^2 * N)`, too. Let's estimate the constant. The number of these pairs is d * (d + 1) / 2, so we get constant 1 / 2. Now let's calculate the constant of number of points considered by `gen`. This number can be calculated with this function:↵

~~~~~↵
ld f(array<ll, N> v)↵
{↵
auto [a, b, c, d, e, f, g] = v;↵
ld res = 0;↵
for (int i=0;i<a;i++)↵
for (int j=0;j<b;j++)↵
for (int u=0;u<c;u++)↵
for (int x=0;x<d;x++)↵
for (int y=0;y<e;y++)↵
for (int z=0;z<f;z++)↵
for (int k=0;k<g;k++)↵
res += (i + 1) * (j + 1) * (u + 1) * (x + 1) * (y + 1) * (z + 1) * (k + 1);↵
return res;↵
}↵
~~~~~↵

The coefficient with `a^2 * b^2 * c^2 * d^2 * e^2 * f^2` is our constant. To find it, I used my program. It is 1 / 128. At all, it is `1 / 2^N` for N variables. It means that the optimization can be efficient if N is small.↵

### Conclusion↵

May be, this program will help someone to find formula for some function. Also it can open brackets, that is necessary if you calculate geometry problems in complex numbers. If you know other ways to use it, I will be happy if you share it.↵

With N = 1 this program is just a Lagrange interpolation, which can be done faster than `O(d^2)`. Maybe, someone will find a way to speed up it with N > 1.↵

История

 
 
 
 
Правки
 
 
  Rev. Язык Кто Когда Δ Комментарий
en2 Английский polosatic 2023-05-12 14:42:44 72
ru9 Русский polosatic 2023-05-12 14:41:57 68
en1 Английский polosatic 2023-05-10 20:14:30 13022 Initial revision for English translation
ru8 Русский polosatic 2023-05-10 19:39:53 16 Мелкая правка: 'N = 1 эта функция &mdash; п' -> 'N = 1 эта программа &mdash; п'
ru7 Русский polosatic 2023-05-10 19:36:11 16 Мелкая правка: 'му-то эта функция поможет у' -> 'му-то эта программа поможет у'
ru6 Русский polosatic 2023-05-10 19:31:20 54
ru5 Русский polosatic 2023-05-10 19:15:31 2 Мелкая правка: 'очек за O(n)\n\nПолуч' -> 'очек за O(N)\n\nПолуч'
ru4 Русский polosatic 2023-05-10 19:14:48 4
ru3 Русский polosatic 2023-05-10 18:21:17 5 Мелкая правка: 'ет `array<int, N>`, а в' -> 'ет `array<ll, N>`, а в'
ru2 Русский polosatic 2023-05-10 18:00:50 15 Мелкая правка: 'е, чем за `O(d^2)`. Возможно' -> 'е, чем за квадрат. Возможно'
ru1 Русский polosatic 2023-05-10 17:52:39 13140 Первая редакция (опубликовано)