斜率优化训练记录

前言

斜率优化一般用于优化dp的转移,借着训练斜率优化的相关问题来提升一些DP思维。选择老学长留下的专题场来练手,由于该场题数较多,以及个人不太愿意长时间进行单一专题训练,因此开此文来记录断续的训练结果和心得。

记录

0x01

由一道简单入门题玩具装箱开头,题意和思路比较简单就不讲了。
代码

SRE实战 互联网时代守护先锋,助力企业售后服务体系运筹帷幄!一键直达领取阿里云限量特价优惠。
#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e5+10,INF=0x3f3f3f3f,mod=1e9+7;
ll f[maxn],sum[maxn],a[maxn],q[maxn],h,t;
inline double K(int i,int j)
{
    double dy=a[j]*a[j]+f[j]-a[i]*a[i]-f[i],dx=a[j]-a[i];
    return dy/dx;
}
int main()
{
    int n,l;
    cin>>n>>l;
    for (int i=1;i<=n;++i)
    {
        int x;
        scanf("%d",&x);
        sum[i]=sum[i-1]+x;
        a[i]=i+1+sum[i];
    }
    a[0]=1;
    h=t=1;
    for (int i=1;i<=n;++i)
    {
        while (h<t&&K(q[h],q[h+1])<2*(sum[i]+i-l))
            h++;
        int j=q[h];
        ll tmp=i-j-1+sum[i]-sum[j]-l;
        f[i]=f[j]+tmp*tmp;
        while (h<t&&K(q[t-1],q[t])>K(q[t-1],i))
            t--;
        q[++t]=i;
    }
    cout<<f[n];
    return 0;
}

0x02

小A与最大字段和还是入门题。题意见原题面,比较简单。思路的话,维护一个普通前缀和与一个梯形前缀和,然后与上题一样通过变形式子写成直线截距式。唯一的不同是,该题Ai可能是负数,因此直线的斜率不能保证单调变化,因此选取最优点时需要二分队列找到首个往后比直线斜率小的点。(上一题由于直线斜率单调增加,因此每次选最优点只要把队首比直线斜率小的点都出队即可)。然后这题要最大值,所以入队时要维护一个上凸壳。
代码

#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=2e5+10,INF=0x3f3f3f3f,mod=1e9+7;
ll s1[maxn],s2[maxn],q[maxn],h,t;
inline double K(int i,int j)
{
    double dy=j*s1[j]-s2[j]-i*s1[i]+s2[i],dx=j-i;
    return dy/dx;
}
int find(double k)
{
    int l=h,r=t,ans;
    while (l<=r)
    {
        int m=(l+r)>>1;
        if (m==t)
        {
            ans=t;
            break;
        }
        if (K(q[m],q[m+1])<=k)
            ans=m,r=m-1;
        else
            l=m+1;
    }
    return ans;
}
int main()
{
    int n;
    cin>>n;
    for (int i=1;i<=n;++i)
    {
        int x;
        scanf("%d",&x);
        s1[i]=s1[i-1]+x;
        s2[i]=s2[i-1]+i*x;
    }
    ll ans=-1e18;
    h=t=1;
    for (int i=1;i<=n;++i)
    {
        int j=q[find(s1[i])];
        ans=max(ans,s2[i]-s2[j]-j*(s1[i]-s1[j]));
        while (h<t&&K(i,q[t-1])>K(q[t],q[t-1]))
            --t;
        q[++t]=i;
    }
    cout<<ans;
    return 0;
}

0x03

HDU2993这题思路不是很难,但是不知道为什么卡读入,相当恶心,比较low的IO优化还过不去,我T了十几发之后用了学长的fread读入的板子才过的。因此极其不建议大家做,脑内AC就可以了。
题意是给个长为n的数列,找一个长度不小于k且平均值最大的子段。换句话说就是找所有点(i,sum[i])中斜率最大的两个点的斜率。
不难想到,我们应该维护一个下凸壳(因为其上方的点肯定无法与之后的点构成更优的解),然后一般可以二分找最优点,但是由这题的性质可以发现,因为sum[i]是递增的,因此每次找到更优的点时,其前面的点可以直接舍弃(它们不可能与后面新加入的点构成更大的斜率了)。因此复杂度可以做到O(N)。
代码(再次强调,不建议做)

#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<endl
#define sz(x) int(x.size())
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define All(x) x.begin(),x.end()
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e5+10,INF=0x3f3f3f3f,mod=1e9+7;
struct FastIO {
    static const int S = 1310720;
    int wpos;
    char wbuf[S];
    FastIO() : wpos(0) { }
    inline int xchar() {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len) pos = 0, len = fread(buf, 1, S, stdin);
        if (pos == len) return -1;
        return buf[pos++];
    }
    inline int xint() {
        int c = xchar(), x = 0, s = 1;
        if (c==-1)
            return -1;
        while (c <= 32) c = xchar();
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }
    ~FastIO() {
        if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
    }
} io;
ll sum[maxn];
int q[maxn],h,t;
inline double K(int i,int j)
{
    return double(sum[i]-sum[j])/(i-j);
}
int main()
{
    int n,k;
    while (n=io.xint(),k=io.xint()) 
    {
        if (n==-1)
            break;
        for (int i=1;i<=n;++i)
        {
            int x=io.xint();
            sum[i]=sum[i-1]+x;
        }
        h=0;
        t=-1;
        double ans=-1;
        for (int i=k;i<=n;++i)
        {
            while (h<t&&K(q[t],q[t-1])>K(i-k,q[t-1]))
                --t;
            q[++t]=i-k;
            while (h<t&&K(q[h],i)<K(q[h+1],i))
                ++h;
            ans=max(ans,K(q[h],i));
        }
        printf("%.2lf\n",ans);
    }
    return 0;
}

0x04

HDU 3045,注意一下转移的合法性,其余的就是常规的斜率优化DP。特别注意斜率的比较上不要写错(手贱做差时取了绝对值,WA了十几发),为了避免精度误差,建议写成乘法形式来比较。
代码

#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"\n"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll a[maxn],sum[maxn],f[maxn];
int q[maxn],h,t;
ll dy(int i,int j)
{
    return f[i]-sum[i]+i*a[i+1]-f[j]+sum[j]-j*a[j+1];
}
ll dx(int i,int j)
{
    return a[i+1]-a[j+1];
}
ll getf(int i,int j)
{
    return f[i]+sum[j]-sum[i]-a[i+1]*(j-i);
}
int main()
{
    int n,k;
    while (scanf("%d%d",&n,&k)!=EOF)
    {
        for (int i=1;i<=n;++i)
            scanf("%I64d",&a[i]);
        sort(a+1,a+1+n);
        for (int i=1;i<=n;++i)
            sum[i]=sum[i-1]+a[i];
        h=t=0;
        for (int i=k;i<=n;++i)
        {
            int j=i-k;
            if (j>=k)
            {
                while (h<t&&dy(j,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(j,q[t-1]))
                    --t;
                q[++t]=j;
            }
            while (h<t&&dy(q[h+1],q[h])<=i*dx(q[h+1],q[h]))
                ++h;
            f[i]=getf(q[h],i);
        }
        printf("%I64d\n",f[n]);
    }
    return 0;
}

0x05

POJ 1180,挺好的一题。题意比较繁琐,建议直接看原题面。
这题用斜率优化的部分也很常规,比较有技巧的是如何把枚举分组数的这一维优化掉,使得复杂度降到O(N)。可以发现,要知道当前这个组的结束时间,与之前分了几个组有关,这样的话转移时就必须枚举了。但是换个角度,我们可以考虑每个分组对后面分组代价的影响,提前计算对当前组造成的s对全局答案的贡献,这样转移时就不需要考虑前面分组对当前的影响了(因为已经在前面计算过了)。因此转移方程可以写成f(j)=min{f(i)+(tsum[j]+s)*(fsum[j]-fsum[i])+s*(fsum[n]-fsum[j])}(0<=i<j),有了这个转移式,剩下的就是通过斜率优化,把i的枚举优化掉,达到线性复杂度。

#include<iostream>
#include<cstdio>
#include<algorithm>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"\n"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
const int maxn=1e5+10,mod=1e9+7,INF=0x3f3f3f3f;
ll fsum[maxn],tsum[maxn],dp[maxn];
int q[maxn],h,t,s,n;
inline ll dy(int i,int j)
{
    return dp[i]-dp[j];
}
inline ll dx(int i,int j)
{
    return fsum[i]-fsum[j];
}
inline ll getv(int i,int j)
{
    return dp[i]+(tsum[j]+s)*(fsum[j]-fsum[i])+s*(fsum[n]-fsum[j]);
}
int main()
{
    while (cin>>n)
    {
        cin>>s;
        for (int i=1;i<=n;++i)
        {
            scanf("%lld%lld",&tsum[i],&fsum [i]);
            tsum[i]+=tsum[i-1],fsum[i]+=fsum[i-1];
        }
        t=h=0;
        for (int j=1;j<=n;++j)
        {
            while (h<t&&dy(q[h+1],q[h])<=(tsum[j]+s)*dx(q[h+1],q[h]))
                ++h;
            dp[j]=getv(q[h],j);
            while (h<t&&dy(j,q[t-1])*dx(q[t],q[t-1])<=dy(q[t],q[t-1])*dx(j,q[t-1]))
                --t;
            q[++t]=j;
        }
        cout<<dp[n]<<endl;
    }
    return 0;
}

0x06

HDU 3480,相比于第一题玩具装箱,多了分组数的限制,其他是一模一样的。出于空间的限制,我们先枚举分组数,这样每次队列就可以清空复用。注意分组数为k时的答案都是由分组数k-1时的答案转移而来,注意边界点,细节见代码

#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"\n"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=1e4+10,mod=1e9+7,INF=0x3f3f3f3f;
ll a[maxn],f[maxn>>1][maxn];
int q[maxn],h,t;
inline ll dy(int i,int j,int k)
{
    return f[k][i]+a[i+1]*a[i+1]-f[k][j]-a[j+1]*a[j+1];
}
inline ll dx(int i,int j)
{
    return a[i+1]-a[j+1];
}
inline ll getf(int i,int j,int k)
{
    return f[k-1][i]+(a[j]-a[i+1])*(a[j]-a[i+1]);
}
int main()
{
    int T;
    cin>>T;
    for (int cas=1;cas<=T;++cas)
    {
        int n,m;
        scanf("%d%d",&n,&m);
        for (int i=1;i<=n;++i)
            scanf("%lld",&a[i]);
        sort(a+1,a+1+n);
        for (int j=1;j<=n;++j)
            f[1][j]=(a[j]-a[1])*(a[j]-a[1]);
        for (int k=2;k<=m;++k)
        {
            h=t=0;
            q[0]=k-1;
            for (int j=k;j<=n;++j)
            {
                while (h<t&&dy(q[h+1],q[h],k-1)<=2*a[j]*dx(q[h+1],q[h]))
                    ++h;
                f[k][j]=getf(q[h],j,k);
                while (h<t&&dy(j,q[t-1],k-1)*dx(q[t],q[t-1])<=dy(q[t],q[t-1],k-1)*dx(j,q[t-1]))
                    --t;
                q[++t]=j;
            }
        }
        printf("Case %d: %lld\n",cas,m>n?0ll:f[m][n]);
    }
    return 0;
}
扫码关注我们
微信号:SRE实战
拒绝背锅 运筹帷幄