Can anyone help me with problem E from the last Div. 2 round?

Revision en1, by xcx0902, 2024-01-28 15:20:53

My solution below gets Wrong Answer (WA) on test 8.

#pragma GCC optimize(2, 3, "Ofast", "inline")
#include <bits/stdc++.h>
// Every `int` in this file is really a 64-bit long long (this macro rewrites the
// keyword), so values up to ~9.2e18 fit; `main` is declared `signed` for that reason.
#define int long long
// Segment-tree index helpers; they assume locals named p (node), l and r
// (node range) are in scope wherever they are used.
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;

// 128-bit integer used to hold products of two residues before reducing mod `mod`.
using ll = __int128;

// N bounds the input arrays, M = 4 * N bounds the segment-tree node arrays.
// All tree values are kept as residues modulo the large constant `mod`, so that
// rescaling by a ratio can be done by multiplying with a modular inverse (inv()).
// NOTE(review): this is only exact if mod is prime and every true range sum is
// below mod — confirm against the problem limits.
const int N = 3e5 + 5, M = N << 2, mod = 1000000000000000009LL;
// n: positions 1..n covered by the tree; m: number of initial points; q: number of
// operations; x/v: initial point positions/values; sum: node sums; mul/add: pending
// lazy multiply/add tags per node.
int n, m, q, x[N], v[N], sum[M], mul[M], add[M];
// Current points as (position, value) pairs, ordered by position then value.
set<pair<int, int>> s;

// Modular inverse of x modulo `mod` (which must be prime), using the identity
// inv(x) = (mod - mod / x) * inv(mod % x), unrolled into a loop that walks the
// chain x, mod % x, ... down to 1 and multiplies the factors together.
int inv(int x) {
    int result = 1;
    for (int cur = x; cur != 1; cur = mod % cur)
        result = (ll)result * (mod - mod / cur) % mod;
    return result;
}

auto findLeft(int p) {
    return prev(s.upper_bound({p, 1e9}));
}

// Returns an iterator to the first point at position >= p whose value is >= 1:
// the point at p itself when present, otherwise the next point to the right
// (or s.end() if there is none). For integer pairs, "not below (p, 1)" is the
// same set as "strictly above (p, 0)".
auto findRight(int p) {
    return s.lower_bound({p, 1});
}

// Cost of position p: the value of the point found by findLeft(p) multiplied by
// the distance from p to the point found by findRight(p). It is 0 when p is
// itself a point.
int calc(int p) {
    const auto before = findLeft(p);
    const auto after = findRight(p);
    const int distance = after->first - p;
    return before->second * distance;
}

// Recomputes node p's sum from its two children, reduced modulo `mod`.
void pushup(int p) {
    const int childSum = sum[p << 1] + sum[p << 1 | 1];
    sum[p] = childSum % mod;
}

// Initialises node p covering [l, r]: identity lazy tags (multiply by 1, add 0),
// leaf sums taken from calc() for the current point set, internal sums from children.
void build(int p, int l, int r) {
    mul[p] = 1;
    add[p] = 0;
    if (l != r) {
        const int split = (l + r) >> 1;
        build(p << 1, l, split);
        build(p << 1 | 1, split + 1, r);
        pushup(p);
    } else {
        sum[p] = calc(l);
    }
}

// Pushes node p's pending (multiply, add) tags onto its two children over the
// range [l, r], then resets p's tags to the identity (multiply by 1, add 0).
void pushdown(int p, int l, int r) {
    const int split = (l + r) >> 1;
    // Composes p's tag onto child c, whose range contains len positions:
    // the child's value x becomes x * mul[p] + add[p].
    auto pushTo = [p](int c, int len) {
        sum[c] = ((ll)sum[c] * mul[p] + (ll)add[p] * len) % mod;
        mul[c] = ((ll)mul[c] * mul[p]) % mod;
        add[c] = ((ll)add[c] * mul[p] + add[p]) % mod;
    };
    pushTo(p << 1, split - l + 1);
    pushTo(p << 1 | 1, r - split);
    mul[p] = 1;
    add[p] = 0;
}

void update1(int p, int l, int r, int L, int R, int k) {
    // cerr << "update1 " << p << " " << l << " " << r << " " << L << " " << R << " " << k << endl;
    if (L <= l && r <= R) {
        add[p] = (add[p] + k) % mod;
        sum[p] = (sum[p] + k * (r - l + 1)) % mod;
        return;
    }
    pushdown(p, l, r);
    if (L <= mid) update1(ls, l, mid, L, R, k);
    if (R > mid) update1(rs, mid + 1, r, L, R, k);
    pushup(p);
}

// Multiplies every position in [L, R] that lies inside node p's range [l, r]
// by k, modulo `mod`.
void update2(int p, int l, int r, int L, int R, int k) {
    const bool covered = L <= l && r <= R;
    if (!covered) {
        pushdown(p, l, r);
        const int split = (l + r) >> 1;
        if (L <= split) update2(p << 1, l, split, L, R, k);
        if (split < R) update2(p << 1 | 1, split + 1, r, L, R, k);
        pushup(p);
        return;
    }
    // A multiplicative tag scales the node sum and both pending tags.
    sum[p] = (ll)sum[p] * k % mod;
    mul[p] = (ll)mul[p] * k % mod;
    add[p] = (ll)add[p] * k % mod;
}

// Returns the sum (mod `mod`) of the positions in [L, R] that lie inside node
// p's range [l, r]. The result is a residue in (-mod, mod).
int query(int p, int l, int r, int L, int R) {
    if (L <= l && r <= R) return sum[p];
    pushdown(p, l, r);
    const int split = (l + r) >> 1;
    const int fromLeft = L <= split ? query(p << 1, l, split, L, R) : 0;
    const int fromRight = R > split ? query(p << 1 | 1, split + 1, r, L, R) : 0;
    return (fromLeft % mod + fromRight) % mod;
}

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n >> m >> q;
    for (int i = 1; i <= m; i++) cin >> x[i];
    for (int i = 1; i <= m; i++) cin >> v[i];
    for (int i = 1; i <= m; i++) s.emplace(x[i], v[i]);
    build(1, 1, n);
    // cerr << "Segment Tree built" << endl;
    while (q--) {
        int op;
        cin >> op;
        if (op == 1) {
            int nx, nv;
            cin >> nx >> nv;
            pair<int, int> L = *findLeft(nx);
            pair<int, int> R = *findRight(nx);
            // cerr << "L = {" << L.first << ", " << L.second << "}" << endl;
            // cerr << "R = {" << R.first << ", " << R.second << "}" << endl;
            s.emplace(nx, nv);
            update1(1, 1, n, L.first, nx, -L.second * (R.first - nx));
            update2(1, 1, n, nx + 1, R.first, (ll)nv * inv(L.second) % mod);
        } else {
            int l, r;
            cin >> l >> r;
            cout << query(1, 1, n, l, r) << endl;
        }
    }
    return 0;
}

History

 
 
 
 
Revisions
 
 
  Rev. Lang. By When Δ Comment
en1 English xcx0902 2024-01-28 15:20:53 3948 Initial revision (published)