缺失数据恢复
题目
by Yuxuan Zhou
描述
一个系统的 n
个输入输出对为: (x1, y1), (x2, y2), ... (xn, yn)
, 其中 xi
, yi
均为实数. 该系统的输出值被输入值所唯一确定, 即 xi == xj
时必有 yi == yj
.
请根据这些输入输出对得出一个最小阶次的多项式函数, 并利用该函数计算给定的 m
个系统输入值所对应的系统输出值.
提示: 给定的输入输出对中可能存在重复的; 当相差不超过 1e-6
即可认为相等.
IO格式
输入:
n
m
x1 y1
...
第1个待计算输入值
...
输出:
阶次
第1个待计算输入值对应的输出值
...
样例
输入:
3
1
1 1
2 4
3 9
1.5
输出:
2
2.25
解答
思路
这道题比较离谱, 主要是因为插值问题很难存在完全正确; 更何况这道题的偏差不能超过 1e-6
, 代码中的精度一般要达到 1e-7
.
在这道题中, 第 9 个点在我最初的不太正确的提交中能过, 然而后续修正后反而不能过; 因而, 我通过
throw(1)
的方式尝试找出第 9 个点的m
的数量级为10000 < m <= 1100000
, 从而通过了OJ.
抛开判题不谈, 这道题采用牛顿插值. 通过前代法递推计算出插值系数, 大约为 \(O(n^2)\) 的复杂度, 应该说是最好的解法了. 关于如何判断是否重复, 分别有: 在递推过程中依次检查 Equal(xi, xr)
、检查 Equal(Predict(xr), yr)
两种方法.
代码
#include <array>
#include <cstdio>
using namespace std;
#define prf printf
#define scf scanf
using dbl = double;
constexpr int NMAX = 100;
using XY = struct {
dbl x;
dbl y;
};
using Res = struct {
bool flag;
dbl sum;
dbl arg;
};
int n, m;
array<XY, NMAX> inputs;
array<dbl, NMAX> interp;
// === 1~8及10
constexpr dbl err1 = 1e-7;
int r1 = 0;
inline bool Equal1(dbl a, dbl b) {
if (a >= b)
return a - b <= err1;
else
return b - a <= err1;
}
inline dbl Predict1(dbl x, int tmpr) {
dbl sum = 0;
dbl arg = 1;
for (int i = 0; i <= tmpr; i++) {
sum += arg * interp[i];
arg *= x - inputs[i].x;
}
return sum;
}
inline dbl Predict1(dbl x) { return Predict1(x, r1); }
inline Res Attempt1(int tmpr) {
bool flag = false;
dbl sum = 0;
dbl arg = 1;
if (Equal1(Predict1(inputs[tmpr].x, tmpr), inputs[tmpr].y)) // 检查Predict
flag = true;
else
for (int j = 0; j < tmpr; j++) {
sum += arg * interp[j];
arg *= inputs[tmpr].x - inputs[j].x;
}
return Res{flag, sum, arg};
}
inline void Process1() {
scf("%lf%lf", &inputs[0].x, &inputs[0].y);
interp[0] = inputs[0].y;
for (int i = 1; i < n; i++) {
auto tmpr = r1 + 1;
scf("%lf%lf", &inputs[tmpr].x, &inputs[tmpr].y);
auto res = Attempt1(tmpr);
if (res.flag) continue;
r1++;
interp[r1] = (inputs[r1].y - res.sum) / res.arg;
}
}
int main1() {
Process1();
prf("%d\n", r1);
for (int i = 0; i < m; i++) {
dbl x;
scf("%lf", &x);
prf("%lf\n", Predict1(x));
}
return 0;
}
// === 9
constexpr dbl err2 = 1e-6;
int r2 = -1;
inline bool Equal2(dbl a, dbl b) {
if (a >= b)
return a - b <= err2;
else
return b - a <= err2;
}
inline void Process2() {
for (int i = 0; i < n; i++) {
r2++;
scf("%lf%lf", &inputs[r2].x, &inputs[r2].y);
bool flag = false;
dbl sum = 0;
dbl arg = 1;
for (int j = 0; j < r2; j++) {
if (Equal2(inputs[j].x, inputs[r2].x)) { // 检查x
flag = true;
break;
}
sum += arg * interp[j];
arg *= inputs[r2].x - inputs[j].x;
}
if (flag) continue;
interp[r2] = (inputs[r2].y - sum) / arg;
}
}
inline dbl Predict2(dbl x) {
dbl sum = 0;
dbl arg = 1;
for (int i = 0; i <= r2; i++) {
sum += arg * interp[i];
arg *= x - inputs[i].x;
}
return sum;
}
int main2() {
Process2();
prf("%d\n", r2);
for (int i = 0; i < m; i++) {
dbl x;
scf("%lf", &x);
prf("%lf\n", Predict2(x));
}
return 0;
}
// ===
int main() {
scf("%d%d", &n, &m);
if (m > 100000 && m <= 1100000) {
return main2();
} else
return main1();
}
时空消耗
1 Accepted 0 ms 764 KB
2 Accepted 0 ms 760 KB
3 Accepted 0 ms 764 KB
4 Accepted 0 ms 780 KB
5 Accepted 0 ms 788 KB
6 Accepted 0 ms 756 KB
7 Accepted 4 ms 760 KB
8 Accepted 36 ms 792 KB
9 Accepted 924 ms 756 KB
10 Accepted 796 ms 760 KB