USACO2022美国公开赛白金组问题一解决方案

(Analysis by Benjamin Qi)
I will refer to contiguous sequences as intervals. Define the value of an interval to be the minimum possible final number it can be converted into.
Subtask 1: Similar to 248, we can apply dynamic programming on ranges. Specifically, if dp[i][j] denotes the value of interval i…j, then
$$dp[i][j]=\begin{cases} A_i & i=j \\ \min_{i\le k<j}\bigl(\max(dp[i][k],\,dp[k+1][j])+1\bigr) & i<j. \end{cases}$$
The time complexity is $O(N^3)$.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Lowers `a` to `b` when `b` is strictly smaller; otherwise `a` keeps its value.
template <class T>
void ckmin(T &a, const T &b) {
    if (b < a) a = b;
}
// Subtask 1: cubic interval DP.
// Reads N and the N values from stdin, computes for every interval lo..hi the
// smallest number it can be collapsed into, and prints the sum of those
// numbers over all intervals.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N;
    std::cin >> N;
    std::vector<int> A(N);
    for (int idx = 0; idx < N; ++idx) std::cin >> A[idx];

    // value[lo][hi]: minimum final number for the interval lo..hi (inclusive).
    std::vector<std::vector<int>> value(N, std::vector<int>(N));
    std::int64_t total = 0;

    // Left ends are visited from right to left so that every row below `lo`
    // is already complete when value[split + 1][hi] is read.
    for (int lo = N - 1; lo >= 0; --lo) {
        value[lo][lo] = A[lo];
        total += value[lo][lo];
        for (int hi = lo + 1; hi < N; ++hi) {
            int best = INT_MAX;
            // Try every position of the final merge: lo..split joined with split+1..hi.
            for (int split = lo; split < hi; ++split) {
                const int merged = std::max(value[lo][split], value[split + 1][hi]) + 1;
                best = std::min(best, merged);
            }
            value[lo][hi] = best;
            total += best;
        }
    }

    std::cout << total << "\n";
}
Subtask 2: We can optimize the solution above by more quickly finding the maximum $k'$ such that $dp[i][k'] \le dp[k'+1][j]$. Then we only need to consider $k \in \{k', k'+1\}$ when computing $dp[i][j]$. Using the observation that $k'$ does not decrease as $j$ increases while $i$ is held fixed leads to a solution in $O(N^2)$:
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Replaces `a` by `b` only when `b` compares strictly less than `a`.
template <class T>
void ckmin(T &a, const T &b) {
    if (b < a) a = b;
}
// Subtask 2: quadratic interval DP.
// Same table as the cubic solution, but for a fixed left end the useful split
// point is tracked by a pointer that only ever moves right as the right end
// grows, so each row costs amortised linear time.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N;
    std::cin >> N;
    std::vector<int> A(N);
    for (int idx = 0; idx < N; ++idx) std::cin >> A[idx];

    // value[lo][hi]: minimum final number for the interval lo..hi (inclusive).
    std::vector<std::vector<int>> value(N, std::vector<int>(N));
    std::int64_t total = 0;

    for (int lo = N - 1; lo >= 0; --lo) {
        value[lo][lo] = A[lo];
        total += A[lo];

        // `cut` advances while the left part lo..cut is no larger than the
        // right part cut+1..hi; it stops at the first position where the left
        // part becomes larger, or at hi if that never happens.
        int cut = lo;
        for (int hi = lo + 1; hi < N; ++hi) {
            while (cut < hi && value[lo][cut] <= value[cut + 1][hi]) ++cut;

            int best = INT_MAX;
            // Split just before `cut`: the right part cut..hi is the larger side.
            if (cut > lo) best = std::min(best, value[cut][hi]);
            // Split just after `cut`: the left part lo..cut is the larger side.
            if (cut < hi) best = std::min(best, value[lo][cut]);

            // lo < hi guarantees at least one of the two candidates was taken.
            value[lo][hi] = best + 1;
            total += value[lo][hi];
        }
    }

    std::cout << total << "\n";
}
Alternatively, finding $k'$ with binary search leads to a solution in $O(N^2 \log N)$.
Subtask 3: Similar to 262144, we can use binary lifting.
For each of v=1,2,3,…, I count the number of intervals with value at least v. The answer is the sum of this quantity over all v.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Subtask 3: binary-lifting style counting.
// For every threshold v = 1, 2, 3, ... the number of intervals whose value is
// at least v is added to the answer; summing those counts over all v gives
// the sum of all interval values.
//
// Input values are expected to be >= 1: the loop starts at v = 1, so a bucket
// for a smaller value would never be consumed and the loop would not end.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N;
    std::cin >> N;
    std::vector<int> A(N);
    for (int &x : A) std::cin >> x;

    // positions_of[v] lists the indices i with A[i] == v.
    // Sizes are compared as int (as the full-credit solution does) so that a
    // negative value cannot be promoted to a huge unsigned bound; such a value
    // is rejected by at() instead of growing the table without limit.
    std::vector<std::vector<int>> positions_of;
    for (int i = 0; i < N; ++i) {
        if (static_cast<int>(positions_of.size()) <= A[i]) {
            positions_of.resize(static_cast<std::size_t>(A[i]) + 1);
        }
        positions_of.at(A[i]).push_back(i);
    }

    // At the top of iteration v, [i, nex[i]) is the longest interval starting
    // at i whose value is below v (empty when nex[i] == i). Index N is a
    // sentinel with nex[N] == N.
    std::vector<int> nex(N + 1);
    std::iota(nex.begin(), nex.end(), 0);

    std::int64_t ans = 0;
    // Once [0, N) itself has value below v, no interval reaches v: stop.
    for (int v = 1; nex[0] != N; ++v) {
        // Intervals [i, r) with r > nex[i] are exactly those with value >= v.
        for (int i = 0; i <= N; ++i) ans += N - nex[i];

        // Two adjacent pieces of value < v merge into one of value <= v.
        // Ascending order matters: nex[i] >= i, so nex[nex[i]] is read before
        // that later entry is itself doubled.
        for (int i = 0; i <= N; ++i) nex[i] = nex[nex[i]];

        // A single element equal to v is an interval of value v on its own.
        if (v < static_cast<int>(positions_of.size())) {
            for (int i : positions_of[v]) nex[i] = i + 1;
        }
    }

    std::cout << ans << "\n";
    return 0;
}
Subtask 4: For each $i$ from $1$ to $N$ in increasing order, consider the values of all intervals with right endpoint $i$. Note that the value $v$ of each such interval must satisfy $v \in [A_i, A_i + \lceil \log_2 i \rceil]$ due to $A$ being sorted. Thus, it suffices to be able to compute for each $v$ the minimum $l$ such that $dp[l][i] \le v$. To do this, we maintain a partition of $1 \ldots i$ into contiguous subsequences such that every contiguous subsequence has value at most $A_i$ and is leftwards-maximal (extending any subsequence one element to the left would cause its value to exceed $A_i$). When transitioning from $i-1$ to $i$, we merge every two consecutive contiguous subsequences $A_i - A_{i-1}$ times and then add the contiguous subsequence $[i,i]$ to the end of the partition.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Subtask 4: non-decreasing input.
// For each right endpoint the prefix is kept as a partition into pieces whose
// value is at most the current element; the piece boundaries tell, for each
// candidate value, how far to the left an interval may start.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N;
    std::cin >> N;
    std::vector<int> A(N);
    for (int idx = 0; idx < N; ++idx) std::cin >> A[idx];
    assert(std::is_sorted(A.begin(), A.end()));

    // Left endpoints of the partition pieces, nearest piece first, so the
    // stored endpoints are in decreasing order and the back is always 0.
    std::deque<int> starts;
    std::int64_t total = 0;

    for (int right = 0; right < N; ++right) {
        if (right > 0) {
            // One pairing round per unit of increase from A[right-1] to A[right].
            for (int level = A[right - 1]; level < A[right]; ++level) {
                if (starts.size() == 1) break;  // nothing left to pair up

                // Merge pieces two at a time, nearest first: each pair keeps
                // the farther (odd-position) left endpoint, and an unpaired
                // last piece is kept as is.
                const int count = static_cast<int>(starts.size());
                int kept = 0;
                for (int pos = 1; pos < count; pos += 2) starts[kept++] = starts[pos];
                if (count % 2 == 1) starts[kept++] = starts[count - 1];
                starts.resize(kept);
            }
        }

        starts.push_front(right);  // the new piece [right, right]

        // Left endpoints in [lower, upper) give value A[right] + depth, where
        // `lower` is the start of the piece reached after 2^depth pieces.
        const int pieces = static_cast<int>(starts.size());
        int upper = right + 1;
        for (int depth = 0; upper != 0; ++depth) {
            const int pick = std::min(pieces - 1, (1 << depth) - 1);
            const int lower = starts[pick];
            total += static_cast<std::int64_t>(upper - lower) * (A[right] + depth);
            upper = lower;
        }
    }

    std::cout << total << "\n";
}
Full Credit: Call an interval relevant if it is not possible to extend it to the left or to the right without increasing its value.
Claim: The number of relevant intervals is $O(N \log N)$.
Proof: See the end of the analysis.
We'll compute the same quantities as in subtask 3, but this time we'll transition from $v-1$ to $v$ in time proportional to the number of relevant intervals with value $v-1$ plus the number of relevant intervals with value $v$. This gives a solution in $O(N \log N)$.
For a fixed value of $v$, say that an interval $[l,r)$ is a segment if it contains no value greater than $v$, and $\min(A_{l-1}, A_r) > v$. Say that an interval is maximal with respect to $v$ if it has value at most $v$ and extending it to the left or right would cause its value to exceed $v$. Note that a maximal interval $[l,r)$ must be relevant, and it must either have value equal to $v$ or be a segment.
My code follows. `ivals[i]` stores all maximal intervals contained within the segment with left endpoint $i$. The following steps are used to transition from $v-1$ to $v$:
Apply `halve` on all segments containing more than one maximal interval (the left endpoints of every such segment are stored by `active`). Before, all intervals within the segment were maximal with respect to $v-1$. After, all intervals within the segment are maximal with respect to $v$.
Add a segment and a maximal interval of the form [i,i+1) for each i satisfying Ai=v, and then merge adjacent segments.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Full-credit solution.
// Reads N and the array from stdin and prints the sum, over all thresholds
// v >= 1, of the number of intervals whose value is at least v.
// Intervals are half-open [l, r) throughout this function.
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N;
cin >> N;
V<int> A(N);
// with_A[v] lists the indices i with A[i] == v (grown on demand).
V<V<int>> with_A;
for (int i = 0; i < N; ++i) {
cin >> A[i];
while ((int)size(with_A) <= A[i]) with_A.emplace_back();
with_A[A[i]].push_back(i);
}
// sum(l ... r): sum of the integers l, l+1, ..., r (arithmetic series).
auto sum_arith = [&](int64_t l, int64_t r) {
return (r + l) * (r - l + 1) / 2;
};
// total number of intervals covered by list of maximal intervals
// For an entry [x, y) followed by one starting at next_x, every left endpoint
// in [x, next_x) is credited with right endpoints up to y; the last entry is
// credited with all of its sub-intervals.
// L must be non-empty (begin(L) is dereferenced unconditionally).
// NOTE(review): the formula presumably relies on L being sorted by left
// endpoint with consecutive entries touching or overlapping - confirm.
auto contrib = [&](const list<pair<int, int>> &L) {
int64_t ret = 0;
for (auto it = begin(L);; ++it) {
auto [x, y] = *it;
if (next(it) == end(L)) {
ret += sum_arith(0, y - x);
break;
} else {
int next_x = next(it)->first;
ret += int64_t(next_x - x) * y - sum_arith(x, next_x - 1);
}
}
return ret;
};
// Starts as the count of all non-empty intervals, N(N+1)/2; it is lowered as
// intervals become covered by the lists of maximal intervals.
int64_t num_at_least = (int64_t)N * (N + 1) / 2;
// Rebuilds the list L of one segment for the next threshold and keeps
// num_at_least in step: the old coverage is added back, then the new
// coverage is subtracted. Lists with at most one entry are left untouched.
auto halve = [&](list<pair<int, int>> &L) {
if (size(L) <= 1) return;
num_at_least += contrib(L);
int max_so_far = -1;
list<pair<int, int>> n_L;
// `it` only moves forward: it is shared across all entries visited below.
auto it = begin(L);
for (auto [x, y] : L) {
// Advance to the last entry that starts at or before y; the entry [x, y)
// is then extended to that entry's right end.
while (next(it) != end(L) && next(it)->first <= y) ++it;
// Keep the extended entry only if it reaches further right than every
// entry already kept.
if (it->second > max_so_far) {
n_L.push_back({x, max_so_far = it->second});
}
}
swap(L, n_L);
num_at_least -= contrib(L);
};
// doubly linked list to maintain segments
// pre[r] is read as the left end of the segment ending at r, and nex[l] as
// the right end of the segment starting at l; initially every boundary is
// linked to itself.
V<int> pre(N + 1);
iota(all(pre), 0);
V<int> nex = pre;
int64_t ans = 0;
V<int> active; // active segments
// (left endpoints of segments whose list holds more than one entry)
// maximal intervals within each segment
// (indexed by the segment's left endpoint)
V<list<pair<int, int>>> ivals(N + 1);
// Loop ends once no interval is counted any more.
for (int v = 1; num_at_least; ++v) {
ans += num_at_least; // # intervals with value >= v
V<int> n_active;
// Step 1: refresh every active segment; it stays active only while its
// list still has more than one entry.
for (int l : active) {
halve(ivals[l]);
if (size(ivals[l]) > 1) n_active.push_back(l);
}
// Step 2: insert every position holding the value v.
if (v < (int)size(with_A)) {
for (int i : with_A[v]) {
// l: left end of the segment ending at i; r: right end of the segment
// starting at i + 1. The two are joined through [i, i + 1).
int l = pre[i], r = nex[i + 1];
// Queue l only if this insertion is what takes its list above one entry.
bool should_add = size(ivals[l]) <= 1;
// The boundaries i and i + 1 are now interior to the joined segment.
pre[i] = nex[i + 1] = -1;
nex[l] = r, pre[r] = l;
ivals[l].push_back({i, i + 1});
// The single-element interval [i, i + 1) is no longer counted.
--num_at_least;
// Append the right neighbour's entries; ivals[i + 1] is left empty.
ivals[l].splice(end(ivals[l]), ivals[i + 1]);
should_add &= size(ivals[l]) > 1;
if (should_add) n_active.push_back(l);
}
}
swap(active, n_active);
}
cout << ans << "\n";
}
Proof of Claim: Let $f(N)$ denote the maximum possible number of relevant subarrays for a sequence of size $N$. We can show that $f(N) \le O(\log N!) \le O(N \log N)$. This upper bound can be attained when all elements of the input sequence are equal.
The idea is to consider a Cartesian tree of a. Specifically, suppose that one of the maximum elements of a is located at position p (1≤p≤N). Then
$$f(N) \le f(p-1) + f(N-p) + \#(\text{relevant intervals containing } p).$$
WLOG we may assume 2p≤N.
Lemma:
$$\#(\text{relevant intervals containing } p) \le O\!\left(p \log \frac{N}{p}\right)$$
Proof of Lemma: We can in fact show that
$$\#(\text{relevant intervals containing } p \text{ with value } a_p + k) \le \min(p, 2^k) \quad \text{for } 0 \le k \le \lceil \log_2 N \rceil.$$
The $p$ comes from all relevant intervals with a fixed value having distinct left endpoints, and the $2^k$ comes from the fact that to generate a relevant interval of value $a_p+k+1$ containing $p$, you must start with a relevant interval of value $a_p+k$ and choose to extend it either to the right or to the left.
To finish, observe that the summation $\sum_{k=0}^{\lceil \log_2 N \rceil} \#(\text{relevant intervals containing } p \text{ with value } a_p + k)$ is dominated by the terms satisfying $\log_2 p \le k \le \lceil \log_2 N \rceil$. ■
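Since $O\!\left(p \log \frac{N}{p}\right) \le O\!\left(\log \binom{N}{p}\right) \le O\!\left(\log \frac{N!}{(p-1)!\,(N-p)!}\right)$,
$$f(N) \le f(p-1) + f(N-p) + \#(\text{relevant intervals containing } p) \le O(\log (p-1)!) + O(\log (N-p)!) + O(\log N! - \log (p-1)! - \log (N-p)!) \le O(\log N!).$$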
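Here is an alternative approach by Danny Mittal that uses both the idea from the subtask 4 solution and the Cartesian tree. He repeatedly finds the index $mid$ of the rightmost maximum $a_{mid}$, recursively solves the problem on $a_{1 \ldots mid-1}$ and $a_{mid+1 \ldots N}$, and then accounts for the intervals containing $a_{mid}$. The total running time is $O(N \log N)$.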
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
/**
 * Alternative solution: builds a Cartesian tree over the input (root at the
 * rightmost maximum) and processes it recursively, adding to {@code answer}
 * the contribution of the intervals that contain each subtree's root.
 */
public class Revisited262144Array {
// Number of input values.
static int n;
// The input values.
static int[] xs;
// left[j] / right[j]: children of j in the tree built in main. Entries of
// missing children keep the default 0; funStuff never reads them because it
// returns early on an empty range.
static int[] left;
static int[] right;
// Scratch arrays of size n holding, per subtree, a list of positions stored
// from the subtree's first index onwards.
// NOTE(review): forward appears to hold right-endpoint breakpoints and
// reverse left-endpoint breakpoints - confirm.
static int[] forward;
static int[] reverse;
// Running total that is printed at the end.
static long answer = 0;
/**
 * Thins the list forward[start .. start+length) in place: keeps every
 * factor-th entry (factor = 2^min(lgFactor, 20)) among the entries before
 * the last one, and always keeps the last entry.
 *
 * @return the new list length (unchanged when lgFactor is 0)
 */
public static int reduceForward(int start, int length, int lgFactor) {
if (lgFactor == 0) {
return length;
}
// The shift is capped at 20 so that 1 << ... cannot overflow.
int factor = 1 << Math.min(lgFactor, 20);
int j = start;
for (int k = start + factor - 1; k < start + length - 1; k += factor) {
forward[j++] = forward[k];
}
forward[j++] = forward[start + length - 1];
return j - start;
}
/**
 * Thins the list reverse[start .. start+length) in place with the same
 * factor as reduceForward. reverse[start] is always left where it is; the
 * remaining kept entries are packed behind it. Nothing is changed when
 * lgFactor is 0 or length <= factor. The new length is not returned.
 */
public static void reduceReverse(int start, int length, int lgFactor) {
if (lgFactor == 0) {
return;
}
int factor = 1 << Math.min(lgFactor, 20);
if (length > factor) {
int j = start + 1;
// The starting offset is chosen so that the stride of `factor` lands on
// the final entry start + length - 1.
for (int k = start + 1 + ((length - factor - 1) % factor); k < start + length; k += factor) {
reverse[j++] = reverse[k];
}
}
}
/**
 * Processes the index range [from, to] whose tree root is mid.
 * Recurses into both children, adds to {@code answer} for the intervals that
 * contain mid, then stores the combined lists for this range at
 * forward/reverse[from ..] and thins them by riseTo - xs[mid] levels.
 *
 * @param riseTo value of the parent's root (Integer.MAX_VALUE at the top)
 * @return length of the list left in forward[from ..]; 0 for an empty range
 *
 * NOTE(review): recursion depth equals the tree height, which can reach n
 * for monotone input - confirm the JVM stack size is adequate.
 */
public static int funStuff(int from, int mid, int to, int riseTo) {
if (from > to) {
return 0;
}
int leftLength = funStuff(from, left[mid], mid - 1, xs[mid]);
int rightLength = funStuff(mid + 1, right[mid], to, xs[mid]);
// The children's lists start at the first index of their ranges.
int leftStart = from;
int rightStart = mid + 1;
int last = mid - 1;
// j-th group on the right: positions in (last, frontier], where the
// frontiers are mid followed by the entries of the right child's list.
for (int j = 1; j <= rightLength + 1; j++) {
int frontier = j > 1 ? forward[rightStart + (j - 2)] : mid;
long weight = frontier - last;
last = frontier;
int lastInside = mid + 1;
// Leftmost breakpoint on the left side (mid itself if there is no left part).
int leftLast = leftLength == 0 ? mid : reverse[leftStart];
// d is the amount added to xs[mid]; only powers of two >= j contribute.
for (int d = 0; d <= 18 && lastInside > leftLast; d++) {
if (1 << d >= j) {
int frontierInside;
if (1 << d == j) {
frontierInside = mid;
} else if (1 << d <= j + leftLength) {
frontierInside = reverse[leftStart + leftLength + j - (1 << d)];
} else {
frontierInside = leftLast;
}
// Positions in [frontierInside, lastInside) paired with this group
// are charged xs[mid] + d each.
long weightInside = lastInside - frontierInside;
lastInside = frontierInside;
answer += weight * weightInside * ((long) (xs[mid] + d));
}
}
}
// Concatenate: left child's list, then mid, then the right child's list.
forward[leftStart + leftLength] = mid;
System.arraycopy(forward, rightStart, forward, leftStart + leftLength + 1, rightLength);
reverse[leftStart + leftLength] = mid;
System.arraycopy(reverse, rightStart, reverse, leftStart + leftLength + 1, rightLength);
int length = reduceForward(leftStart, leftLength + 1 + rightLength, riseTo - xs[mid]);
reduceReverse(leftStart, leftLength + 1 + rightLength, riseTo - xs[mid]);
return length;
}
/**
 * Reads n and the n values from stdin (one line each), builds the tree with
 * a stack, runs funStuff from the root and prints the answer.
 */
public static void main(String[] args) throws IOException {
BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
n = Integer.parseInt(in.readLine());
StringTokenizer tokenizer = new StringTokenizer(in.readLine());
xs = new int[n];
for (int j = 0; j < n; j++) {
xs[j] = Integer.parseInt(tokenizer.nextToken());
}
left = new int[n];
right = new int[n];
ArrayDeque<Integer> stack = new ArrayDeque<>();
for (int j = 0; j < n; j++) {
// Popping on <= means that among equal values the later index ends up
// higher in the tree; the last index popped becomes j's left child.
while (!stack.isEmpty() && xs[stack.peek()] <= xs[j]) {
left[j] = stack.pop();
}
if (!stack.isEmpty()) {
right[stack.peek()] = j;
}
stack.push(j);
}
// The bottom of the stack is the root of the whole tree.
while (stack.size() > 1) {
stack.pop();
}
forward = new int[n];
reverse = new int[n];
funStuff(0, stack.pop(), n - 1, Integer.MAX_VALUE);
System.out.println(answer);
}
}