LOJ#3381 函数调用 题解

分析

函数调用,满足不出现递归,本身就构成了一张 DAG。

由于乘法操作是全局乘法,所以可以提前预处理乘法操作,而加法操作最终的贡献,取决于在它之后乘上多少次。

\(pls_i\) 为第 \(i\) 个函数执行的加法的值,如果没有的话则为 \(0\)\(id_i\) 为第 \(i\) 个函数要执行加法的元素的编号,没有的话为 \(0\)\(mul_i\) 为第 \(i\) 个函数对全局的乘法值,如果没有的话则为 \(1\)

\(cnt_i\) 为第 \(i\) 个函数加法操作之后又被乘了多少。

对于函数 \(i\) 要调用的函数 \(x\),连边 \((i \rightarrow x)\)。用 DFS 可以预处理出当运行完任意一个函数时,对全局的乘法操作,而函数 \(i\) 必然在 \(x\) 之后操作,因此 \(x\) 加法操作的最终值依赖于 \(i\),因此处理 \(i\) 要早于处理 \(x\),拓扑排序即可。

但是函数的调用是有序的,依次从左到右。因此必须新建立一个节点 \(0\),向所有调用的函数连边,并且在遍历 \(x\) 能到达的所有点时,必须要按照倒序。

DFS 之后,\(mul_0\) 即为全局最终乘上的值,而对于每一个加法函数 \(i\),都会对 \(a_{id_i}\) 产生 \(pls_i \cdot cnt_i\) 的贡献。

CODE

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pb push_back
int read() {
	int a=0, f=1; char c=getchar();
	while(!isdigit(c)) {
		if(c=='-') f=-1;
		c=getchar();
	}
	while(isdigit(c)) a=a*10+c-'0', c=getchar();
	return a*f;
}
const int M=1e6+5, N=1e5+5, mod=998244353;
int n, m, a[N], id[N], pls[N], mul[N], cnt[N];
int in[N], op[N];
bool v[N];
vector<int> p[M];
void dfs(int x) {
	if(v[x]) return;
	v[x]=1;
    // 关于为什么这里要用v数组来判环
    // 因为从0连边之前是DAG,但是连边之后就可能不是了
	for(auto y:p[x]) {
		dfs(y);
		(mul[x]*=mul[y])%=mod;
	}
}
void toposort() {
	queue<int> q;
	for(int i=0;i<=m;++i) if(!in[i]) q.push(i);
	while(q.size()) {
		int x=q.front(); q.pop();
		int dlt=1;
		for(int i=p[x].size()-1;~i;--i) {
        // 倒序遍历
			int y=p[x][i];
			(cnt[y]+=cnt[x]*dlt)%=mod;
            // 这里的层次一定要分清楚
            // y的加法操作只会被x影响,要按照拓扑序处理,而不是看全局乘上的值
			(dlt*=mul[y])%=mod;
			if(--in[y]==0) q.push(y);
		}
	}
}
signed main() {
	freopen("call.in","r",stdin);
	freopen("call.out","w",stdout);
	n=read();
	for(int i=1;i<=n;++i) a[i]=read();
	m=read();
	for(int i=1;i<=m;++i) {
		op[i]=read();
		if(op[i]==1) id[i]=read(), pls[i]=read(), mul[i]=1;
		else if(op[i]==2) mul[i]=read();
		else {
			int k=read();
			mul[i]=1;
			for(int j=1;j<=k;++j) {
				int x=read();
				p[i].pb(x), ++in[x];
			}
		}
	}
	int q=read();
	for(int i=1;i<=q;++i) {
		int x=read();
		p[0].pb(x), ++in[x];
	}
	cnt[0]=mul[0]=1;
	dfs(0);
	toposort();
	for(int i=1;i<=n;++i) (a[i]*=mul[0])%=mod;
	for(int i=1;i<=m;++i) if(op[i]==1) (a[id[i]]+=cnt[i]*pls[i]%mod)%=mod;
	for(int i=1;i<=n;++i) printf("%lld%c",a[i]," \n"[i==n]);
}

LOJ#3381 函数调用 题解
https://yozora0908.github.io/2022/loj3381-solution/
作者
yozora0908
发布于
2022年8月20日
许可协议