思路:
首先我们计算出solve(m):中位数大于等于m的方案数,那么最后答案就是solve(m) - solve(m+1)
那么怎么计算sovle(m)呢?
对于一个区间[l,r],如果它的中位数大于等于m,那么这个区间中 (大于等于m的数的个数) > (小于m的数的个数)
如果记a[i]大于等于m为+1,小于m 为 -1,即 sum(l, r) > 0
我们枚举右端点 i ,并且同时计算sum(1, i) ,那么对于这个右端点,我们只要找到之前的 sum 中 < sum(1, i)的个数(左端点的个数),这个可以用树状数组维护
但是我们有一个O(n)的方法求,用了类似莫队的方法,记s[i]为之前的sum为i的个数,add为上一个小于sum(1, i-1)的个数,对于当前的sum,
如果它要加1,add += s[sum], sum++
如果它要减1,sum --, add -= s[sum]
这样得出的add就是当前的小于sum(1, i)的个数
代码:
#includeusing namespace std;#define fi first#define se second#define pi acos(-1.0)#define LL long long//#define mp make_pair#define pb push_back#define ls rt<<1, l, m#define rs rt<<1|1, m+1, r#define ULL unsigned LL#define pll pair #define pii pair #define piii pair #define mem(a, b) memset(a, b, sizeof(a))#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout);//headconst int N = 2e5 + 5;int a[N], cnt[N*2], n, m;LL solve(int m) { int s = n; mem(cnt, 0); cnt[s] = 1; LL add = 0, ans = 0; for (int i = 1; i <= n; i++) { if(a[i] >= m) add += cnt[s], s++; else s--, add -= cnt[s]; cnt[s]++; ans += add; } return ans;}int main() { scanf("%d %d", &n, &m); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); printf("%lld\n", solve(m) - solve(m+1)); return 0;}