题目链接

AcWing 1236.递增三元组

引用

z林深时见鹿 AcWing 1236. 递增三元组( 二分解法 )

macat AcWing 1236. 递增三元组 (二分+双指针+前缀和)

冷丁Hacker AcWing 1236. 递增三元组(求赞谢谢兄弟)

田所浩二 AcWing 1236. 递增三元组(模板二分的应用)

小呆呆 AcWing 1236. 递增三元组

题目描述

给定三个整数数组

A=[A1,A2,…AN],
B=[B1,B2,…BN],
C=[C1,C2,…CN],

请你统计有多少个三元组 (i,j,k) 满足:

1i,j,kN1≤i,j,k≤N
Ai<Bj<CkA_i<B_j<C_k

样例

输入格式

第一行包含一个整数 N。

第二行包含 N 个整数 A1,A2,…AN。

第三行包含 N 个整数 B1,B2,…BN。

第四行包含 N 个整数 C1,C2,…CN。

输出格式

一个整数表示答案。

数据范围

1N105,1≤N≤105,
0Ai,Bi,Ci1050≤A_i,B_i,C_i≤105

输入样例:

3
1 1 1
2 2 2
3 3 3

输出样例:

27

算法1

(前缀和) O(n)O(n)

因为我们分别从A,B,C中取出一个数,构成 Ai<Bj<CkA_i < B_j < C_k ,如果只是单纯的朴素算法(暴力),那么就是 O(n3)O(n^3) 的时间复杂度必定超时。那么能不能找出其中的关系呢

其实细心的你肯定发现了 BjB_j 作为中间数字和其他两位数的关系,我们只要找出A中有多少了数小于当前的B,C中有多少个数大于B,最后相乘累加即可,相乘是因为排列问题

前缀和:当y总说出可以用前缀和方法做的时候我还在疑惑该怎么做,被他这么一讲,好像对于前缀和又有了新的认识,之前只是停留在对于数组的前缀和计算,但是这题用到的不是数组本身的前缀和,而是数组中的数在这个数组里出现次数的前缀和
怎么理解呢?我们计算A、C中数字出现的次数,然后对他们两个求前缀和,接着只要枚举B,算一下A中小于 BjB_j 的数的个数和C中大于 CkC_k 的数的个数的乘积,然后累加就行

时间复杂度 O(n)O(n)

前缀和的时间复杂度是 O(n)O(n)

C++ 代码

#include <cstdio>
#include <cstring>

using namespace std;

const int N = 1e5 + 10;

int n;

int a[N] , b[N] , c[N];
int as[N] , cs[N];
int cnt[N] , s[N];

int main()
{
    scanf("%d" , &n);

    //因为数据范围是从零开始,计算前缀和时我们习惯下标从1开始,所以数组中的数都往前移
    for(int i = 0; i < n; i ++) scanf("%d" , &a[i]) , a[i] ++;
    for(int i = 0; i < n; i ++) scanf("%d" , &b[i]) , b[i] ++;
    for(int i = 0; i < n; i ++) scanf("%d" , &c[i]) , c[i] ++;

    //计算as[]
    for(int i = 0; i < n; i ++) cnt[a[i]] ++;
    for(int i = 1; i < N; i ++) s[i] = s[i - 1] + cnt[i];
    //b[i]- 1是因为b增加了1
    for(int i = 0; i < n; i ++) as[i] = s[b[i] - 1];

    //清空cnt[] 和 s[]
    memset(cnt , 0 , sizeof cnt);
    memset(s , 0 , sizeof s);

    //计算过cs[]
    for(int i = 0; i < n; i ++) cnt[c[i]] ++;
    for(int i = 1; i < N; i ++) s[i] = s[i - 1] + cnt[i];
    //本来b[i]应该减1,但是因为已经加过了,就不需要了,s[N - 1] - s[b[i]] ==> s[r] - s[l - 1]    s[l ~ r] = s[r] - s[l - 1]
    for(int i = 0; i < n; i ++) cs[i] = s[N - 1] - s[b[i]];

    long long res = 0;
    for(int i = 0; i < n; i ++)
        res += (long long)as[i] * cs[i];
    
    printf("%lld" , res);

    return 0;
}

算法2

(sort+二分) O(nlogn)O(nlog_n)

sort+二分:首先我们先将A、C数组排序,B排不排都无所谓。然后遍历B,二分找出A中小于 BjB_j 的数的下标,判断一下这个数是不是小于 BjB_j ,如果不是小于的话就说明A中最小的数都大于它,故小于 BjB_j
的数为零。同理对于C数组也一样,然后将得到的个数相乘再累加即可

时间复杂度 O(nlogn)O(nlog_n)

sort+二分的时间复杂度是 O(nlogn)O(nlog_n)

C++ 代码

#include <cstdio>
#include <algorithm>

using namespace std;

const int N = 100010;

int n;
int a[N] , b[N] , c[N];

int main()
{
    scanf("%d" , &n);
    
    for(int i = 0; i < n; i ++) scanf("%d" , &a[i]);
    for(int i = 0; i < n; i ++) scanf("%d" , &b[i]);
    for(int i = 0; i < n; i ++) scanf("%d" , &c[i]);
    
    sort(a , a + n) , sort(c , c + n);
    
    long long res = 0;
    for(int i = 0; i < n; i ++)
    {
        int l = 0 , r = n - 1 , mid;
        long long x , y;
        while(l < r)
        {
            mid = l + r + 1 >> 1;
            if(a[mid] < b[i]) l = mid;
            else r = mid - 1;
        }
        if(a[l] < b[i]) x = l + 1;
        else x = 0;
        
        l = 0 , r = n - 1;
        while(l < r)
        {
            mid = l + r >> 1;
            if(c[mid] > b[i]) r = mid;
            else l = mid + 1;
        }
        if(c[l] > b[i]) y = n - l;
        else y = 0;
        
        res += x * y;
    }
    
    printf("%lld" , res);
    
    return 0;
}

算法3

(双指针) O(nlogn)O(nlog_n)

双指针:首先双指针算法运用的场景必须具备单调性,所以我们同二分一样要先排序,只不过查找的时候用双指针算法,因为查找最多是n次,所以整个算法的时间复杂度是 O(nlogn)O(nlog_n),因为双指针是不回溯的,所以它的时间复杂度才会是 O(n)O(n)。因此我们计算时应该把三个数组都排序一遍,然后用两个变量计数就行了

时间复杂度 O(nlogn)O(nlog_n)

此时整个算法的时间复杂度是 O(nlogn)O(nlog_n) \Longleftrightarrow O(n+nlogn)O(n + nlog_n)

C++ 代码


#include <cstdio>
#include <algorithm>

using namespace std;

const int N = 100010;

int n;
int a[N] , b[N] , c[N];

int main()
{
    scanf("%d" , &n);

    for(int i = 0; i < n; i ++) scanf("%d" , &a[i]);
    for(int i = 0; i < n; i ++) scanf("%d" , &b[i]);
    for(int i = 0; i < n; i ++) scanf("%d" , &c[i]); 

    sort(a , a + n) , sort(b , b + n) , sort(c , c + n);

    long long res = 0;
    
    //因为指针不能回溯,一旦回溯了时间复杂度就是n三方了,所以将变量定义在循环外
    int j = 0 , k = 0;
    for(int i = 0; i < n; i ++)
    {   
        //指针停下来时一定在不满足条件的那个数上
        while(j < n && a[j] < b[i]) j ++;
        //取等号的原因是因为停下来的时候是不满足大于的,第k个数可能小于也可能等于,这样就会导致可能当前多算了一个,因此指针停下来时一定要在满足条件的那个数上,下一次的时候就能继续了
        while(k < n && c[k] <= b[i]) k ++;
        res += (long long) j * (n - k);
    }

    printf("%lld" , res);
    
    return 0;
}