拉格朗日插值

发布于 2020-10-10  64 次阅读


Lagrange插值

求拉格朗日插值需要注意的几个点,求f(x)之前先确定好需要几个点的坐标,求得所有的y[i],这个只需要求一次,pre[]和suf[]在每求一次f(x)都需要计算一次,然后每一项y[i]需乘以pre[i-1]*suf[i+1]*inv(fac[i-1])*inv(fac[n+1-i]),每一项乘完之后要去mod,最后在乘以(-1)^{n+1-i}确定符号,最后加到答案里

inline ll f(ll x){
    ll ans = 0;
    pre[0] = suf[k+2+1] = 1;
    for (int i=1; i<=k+2; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+2; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+2; i++){
        int flag = 1;
        if ((k+2-i)&1)  flag = -1;
        ans = (ans+y[i]*pre[i-1]%mod*suf[i+1]%mod*inv[i-1]%mod*inv[k+2-i]%mod*flag)%mod;
    }
    return (ans+mod)%mod;
}

对于一个n次多项式,可以用系数表示法表示,这样需要n+1个系数;也可以用点值表示法表示,需要n+1个横坐标不同的点。也就是说,n+1个点可以唯一确定一个最高次不超过n次的多项式,可以用n+1个点来求得一个多项式。

给定n个点,确定一个多项式。
若要根据这n个点求得多项式的所有系数,则需要时间复杂度为O(n^3)的高斯消元法来解决,且存在精度问题。
利用拉格朗日插值法可以在O(n^2)的复杂度求得多项式在某个位置的值f(x).设已经给出了n+1个点(xi, yi).
公式:f(x) = \sum_{i=0}^n y[i]*(\prod_{j=0,j \neq i}^n\frac{x-x[j]}{x[i]-x[j]})

当这n+1个点的横坐标毫无规律时,只能O(n^2)求,对于每个y[i],先把后面乘积符号的分子全都乘起来,分母也乘起来,然后分母去逆元即可。
当n+1个点的横坐标连续时,假设是[1, n+1],则可以优化成O(n)来处理。
f(x) = \sum_{i=0}^ny[i]*(\prod_{j=0,j \neq i}^n\frac{x-x[j]}{x[i]-x[j]}) = \sum_{i=0}^ny[i]*(\prod_{j=0, j \neq i}^n\frac{x-j}{i-j}) =\sum_{i=0}^ny[i]*(\frac{pre[i-1]*suf[i+1]}{fac[i-1]*fac[n-i+1]*(-1)^{n-i+1}})

pre[i] = pre[i-1]*(x-i), pre[0] = 1
suf[i] = suf[i+1]*(x-i), suf[n+2] = 1

pre[]和suf[]都可以线性求得,阶乘fac[]提前维护即可。

求拉格朗日插值需要注意的几个点,求f(x)之前先确定好需要几个点的坐标,求得所有的y[i],这个只需要求一次,pre[]和suf[]在每求一次f(x)都需要计算一次,然后每一项y[i]需乘以pre[i-1]*suf[i+1]*inv(fac[i-1])*inv(fac[n+1-i]),每一项乘完之后要去mod,最后在乘以(-1)^{n+1-i}确定符号,最后加到答案里

一个经典应用:求式子f(x) = \sum_{i=0}^x i^k.
这是一个k+1次多项式,证明略。所以求得前k+2个点后,可以O(k)求得任意一个f(x).当前面多一个求和符号时,最高次+1,多几个就+几。

P4781 拉格朗日插值

普通拉格朗日插值,O(n^2)求法。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int mod = 998244353, maxn = 2003;

ll n, k;
ll x[maxn], y[maxn];
inline void extgcd(ll a, ll b, ll &x, ll &y){
    if (b==0){
        x = 1, y = 0;
        return;
    }
    extgcd(b, a%b, y, x);
    y -= x*(a/b);
}
inline ll inv(ll a){
    ll x, y;
    a = (a%mod+mod)%mod;
    extgcd(a, mod, x, y);
    return (x%mod+mod)%mod;
}
int main(){
    ll ans = 0;
    cin >> n >> k;
    for (int i=1; i<=n; i++)    cin >> x[i] >> y[i];
    for (int i=1; i<=n; i++){
        for (int j=1; j<=n; j++)    if (j!=i)
            y[i] = y[i]*(k-x[j])%mod;
        ll tmp = 1;
        for (int j=1; j<=n; j++)    if (j!=i)
            tmp = tmp*(x[i]-x[j])%mod;
        ans += y[i]*inv(tmp)%mod;
    }
    ans = (ans%mod+mod)%mod;
    cout << ans;
    return 0;
}

CF622F The Sum of the k-th Powers

求式子: \sum_{i=0}^n i^k.

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const ll maxn = 1e6+5, mod = 1e9+7;
int n, k;
ll fac[maxn], pre[maxn], suf[maxn], y[maxn], inv[maxn];
ll pow_mod(ll a, ll b){
    a %= mod, b %= mod-1;
    if (!a) return 0;
    ll ans = 1;
    while (b){
        if (b&1)    ans = ans*a%mod;
        b >>= 1;
        a = a*a%mod;
    }
    return ans;
}
inline ll f(ll x){
    ll ans = 0;
    pre[0] = suf[k+2+1] = 1;
    for (int i=1; i<=k+2; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+2; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+2; i++){
        int flag = 1;
        if ((k+2-i)&1)  flag = -1;
        ans = (ans+y[i]*pre[i-1]%mod*suf[i+1]%mod*inv[i-1]%mod*inv[k+2-i]%mod*flag)%mod;
    }
    return (ans+mod)%mod;
}
int main(){
    cin >> n >> k;
    for (int i=1; i<=k+2; i++)  y[i] = (y[i-1]+pow_mod(i, k))%mod;
    fac[0] = 1, inv[0] = 1;
    for (int i=1; i<=k+2; i++)  fac[i] = fac[i-1]*i%mod, inv[i] = pow_mod(fac[i], mod-2);
    cout << f(n);
    return 0;
}

P4593 教科书般的亵渎

看题都看了半天。。懒得写了。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const ll mod = 1e9+7;
ll n, m, k;
ll a[55];
ll y[55], pre[55], suf[55], fac[55];
inline ll pow_mod(ll a, ll b){
    a %= mod;
    if (!a) return 0;
    b %= mod-1;
    ll ans = 1;
    while (b){
        if (b&1)    ans = ans*a%mod;
        b >>= 1;
        a = a*a%mod;
    }
    return ans;
}
inline ll inv(ll a){
    a = (a%mod+mod)%mod;
    return pow_mod(a, mod-2);
}
void init(){
    y[1] = 1;
    for (int i=2; i<=k+2; i++)  y[i] = (y[i-1]+pow_mod(i, k))%mod;
    return;
}
ll f(ll x){
    x %= mod;
    ll ans = 0;
    pre[0] = suf[k+2+1] = 1;
    for (int i=1; i<=k+2; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+2; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+2; i++){
        ans = (ans+y[i]*pre[i-1]%mod*suf[i+1]%mod*inv(fac[i-1])%mod*inv(fac[k+2-i])%mod*(((k+2-i)&1)?-1:1))%mod;
    }
    return (ans+mod)%mod;
}
int main(){
    int t;
    fac[0] = fac[1] = 1;
    for (int i=2; i<=53; i++)   fac[i] = fac[i-1]*i%mod;
    scanf ("%d", &t);
    while (t--){
        ll ans = 0;
        scanf ("%lld %lld", &n, &m);    k = m+1;
        for (int i=1; i<=m; i++)    scanf ("%lld", a+i);
        init();
        sort (a+1, a+m+1);
        for (int i=1; i<=m; i++){
            ans = (ans+f(n))%mod;
            for (int j=i; j<=m; j++)    ans = (ans-pow_mod(a[j], k))%mod;
            for (int j=i+1; j<=m; j++)    a[j] -= a[i];
            n -= a[i];
        }
        ans = (ans+f(n))%mod;
        printf ("%lld\n", (ans%mod+mod)%mod);
    }
    return 0;
}

LOJ6024 XLKxc

给定k, a, n, d, 求式子:\sum_{i=0}^n \sum_{j=1}^{a+id} \sum_{l=1}^{j} l^k (\%p)
对式子进行转化:\sum_{i=0}^n \sum_{j=1}^{a+id} \sum_{l=1}^{j} l^k (\%p) = \sum_{i=0}^n \sum_{j=1}^{a+id} f(j) (\%p) = \sum_{i=0}^{n}g(a+id) = h(n)
根据多项式的定义有:
f(x) = \sum_{i=1}^{x}i^k是k+1次多项式。
g(x) = \sum_{j=1}^{x}f(x)是k+2次多项式
h(x) = \sum_{i=0}^{x}g(x)是k+3次多项式
使用拉格朗日插值法可在O(k^2)复杂度内解决问题。

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const ll mod = 1234567891;

ll k, a, n, d;
ll fac[150], suf[150], pre[150];
ll y[150], y2[150], y3[150];
inline ll pow_mod(ll a, ll b){
    a = (a%mod+mod)%mod;
    if (!a) return 0;
    b %= mod-1;
    ll ans = 1;
    while (b){
        if (b&1)    ans = ans*a%mod;
        b >>= 1;
        a = a*a%mod;
    }
    return ans;
}
inline ll inv(ll a){
    return pow_mod(a, mod-2);
}

inline ll f(ll x){
    ll ans = 0;
    x %= mod;
    pre[0] = suf[k+2+1] = 1;
    for (int i=1; i<=k+2; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+2; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+2; i++){
        int flag = 1;
        if ((k+2-i)&1)    flag = -1;
        ans = (ans+y[i]*pre[i-1]%mod*suf[i+1]%mod*inv(fac[i-1])%mod*inv(fac[k+2-i])%mod*flag)%mod;
    }
    return (ans%mod+mod)%mod;
}
inline ll g(ll x){
    ll ans = 0;
    x %= mod;
    pre[0] = suf[k+3+1] = 1;
    for (int i=1; i<=k+3; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+3; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+3; i++){
        int flag = 1;
        if ((k+3-i)&1)  flag = -1;
        ans = (ans+y2[i]*pre[i-1]%mod*suf[i+1]%mod*inv(fac[i-1])%mod*inv(fac[k+3-i])%mod*flag)%mod;
    }
    return (ans%mod+mod)%mod;
}
inline ll h(ll x){
    ll ans = 0;
    x %= mod;
    pre[0] = suf[k+4+1] = 1;
    for (int i=1; i<=k+4; i++)  pre[i] = pre[i-1]*(x-i)%mod;
    for (int i=k+4; i>=1; i--)  suf[i] = suf[i+1]*(x-i)%mod;
    for (int i=1; i<=k+4; i++){
        int flag = 1;
        if ((k+4-i)&1)  flag = -1;
        ans = (ans+y3[i]*pre[i-1]%mod*suf[i+1]%mod*inv(fac[i-1])%mod*inv(fac[k+4-i])%mod*flag)%mod;
    }
    return (ans%mod+mod)%mod;
}
int main(){
    int t;
    fac[0] = 1;
    for (int i=1; i<=130; i++)  fac[i] = fac[i-1]*i%mod;
    scanf ("%d", &t);
    while (t--){
        scanf ("%lld %lld %lld %lld", &k, &a, &n, &d);
        y[0] = 0, y2[0] = 0, y3[0] = 0;
        for (int i=1; i<=k+2; i++)  y[i] = (y[i-1]+pow_mod(i, k))%mod;           //y只要求1次即可,之后会重复用到
        for (int i=1; i<=k+3; i++)  y2[i] = (y2[i-1]+f(i))%mod;
        y3[0] = g(a);
        for (int i=1; i<=k+4; i++)  y3[i] = (y3[i-1]+g(a+i*d))%mod;
        printf ("%lld\n", h(n));
    }
    return 0;
}

ps: 对于一个多项式f(x),前面加一个求和符号后,还是一个多项式,最高次+1.


一沙一世界,一花一天堂。君掌盛无边,刹那成永恒。