题面
花花有一棵带 n 个顶点的树 T,每个节点有一个点权 ai。 有一天,他认为拥有两棵树更好一些。所以,他从 T 中删去了一条边。 第二天,他认为三棵树或许又更好一些。因此,他又从他拥有的某一棵树中去除了一条边。 如此往复。每一天,花花都会删去一条尚未被删去的边,直到他得到了一个包含了 n 棵只有一个点的树的森林。 定义一条简单路径(顶点不重复的路径)的权值为路径上点权之和,一棵树的直径为树上权值最大的简单路径。花花认为树最重要的特征就是它的直径。所以他想请你算出任一时刻他拥有的所有树的直径的乘积。因为这个数可能很大,他要求你输出乘积对 109 + 7 取模之后的结果。
Input
输入的第一行包含一个整数 n,表示树 T 上顶点的数量。 下一行包含 n 个空格分隔的整数 ai,表示顶点的权值。 之后的 n - 1 行中,每一行包含两个用空格分隔的整数 xi 和 yi,表示节点 xi 和 yi 之间连有一条边,编号为 i。 再之后 n - 1 行中,每一行包含一个整数 kj,表示在第 j 天里会被删除的边的编号。
Output
输出 n 行。 在第 i 行,输出删除 i - 1 条边之后,所有树直径的乘积对 109 + 7 取模的结果。
分析
当时考试的时候想的是lca暴力乱搞一通,然后WA完了:(
然后考完了又是追悔莫及
这种正难则反,删边等效为加边的题做了多少次了..还是没长记性
思路大概就出来了,从每个节点为一棵树,没有边的末状态倒着加边,加边的时候合并树,并求直径
这里就有个简单的结论,如果a,b是树1的直径端点,c,d是树2的直径端点
则合并后的树的直径端点一定是a,b,c,d其中两个,这个结论是很显然的,根据两次dfs求直径的思想,对于树1中任意一点出发,能搜到距离最远的点作为直径的一个端点,再从此端点出发,搜到一个与它距离最远的点,就得到了直径。
所以反过来,对于树1上任意一点,到它距离最远的点要么是a,要么是b,树2同理。
接下来每次把合并后的直径乘进去,合并前的两个直径除掉(逆元)
debug好久。。先挂lca,后挂逆元,干脆我自挂东南枝算了。。。
不知道为啥费马小定理求的逆元不太对..理论上没毛病啊??被迫改成exgcd
代码
#includeusing namespace std;#define N 100100#define ll long long#define mod 1000000007long long ans,len;long long print[N];int n,m,t,cnt;int val[N],dep[N],fas[N],first[N];int d[N],q[N],ue[N],ve[N];int fa[N][20];struct email{ int u,v; int nxt;}e[N*4];struct _dot_{ int d1,d2; }tmp,dot[N*2];inline int find(int x){ return fas[x]==x?x:fas[x]=find(fas[x]);}inline void read(int &x){ x=0;char ch=getchar(); while(ch>'9'||ch<'0'){ch=getchar();} while(ch<='9'&&ch>='0'){x=x*10+ch-'0';ch=getchar();}}inline void add(int u,int v){ e[++cnt].nxt=first[u];first[u]=cnt; e[cnt].u=u;e[cnt].v=v;}void dfs(int u,int f){ fa[u][0]=f; d[u]=d[f]+val[u]; dep[u]=dep[f]+1; for(int i=1;i<=17;i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=first[u];i;i=e[i].nxt) { int v=e[i].v; if(v==f)continue; dfs(v,u); }}int lca(int x,int y){ if(dep[x] =0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0];}inline int dis(int u,int v,int l){ return d[u]+d[v]-2*d[l]+val[l];} int x,y;void exgcd(int a,int b){ if(!b) { x=1,y=0; return; } exgcd(b,a%b); int t=x; x=y; y=t-a/b*y;}int inv(int k){ exgcd(k,mod); return (x%mod+mod)%mod;}int main(){ memset(d,0,sizeof(d)); read(n);print[n]=1; for(int i=1;i<=n;i++) { read(val[i]); print[n]=(print[n]*val[i])%mod; fas[i]=dot[i].d1=dot[i].d2=i; } for(int i=1;i =1;i--) { int u=ue[q[i]],v=ve[q[i]]; int fax=find(u),fay=find(v); fas[fax]=fay; len=-1; ll s1=dis(dot[fax].d1,dot[fay].d2,lca(dot[fax].d1,dot[fay].d2));//12 ll s2=dis(dot[fax].d1,dot[fay].d1,lca(dot[fax].d1,dot[fay].d1));//11 ll s3=dis(dot[fax].d2,dot[fay].d2,lca(dot[fax].d2,dot[fay].d2));//22 ll s4=dis(dot[fax].d2,dot[fay].d1,lca(dot[fax].d2,dot[fay].d1));//21 ll s5=dis(dot[fax].d1,dot[fax].d2,lca(dot[fax].d1,dot[fax].d2)); ll s6=dis(dot[fay].d1,dot[fay].d2,lca(dot[fay].d1,dot[fay].d2)); len=max(s1,s2);len=max(len,s3); len=max(len,s4);len=max(len,s5);len=max(len,s6); if(len==s1) tmp.d1=dot[fax].d1,tmp.d2=dot[fay].d2; if(len==s2) tmp.d1=dot[fax].d1,tmp.d2=dot[fay].d1; if(len==s3) tmp.d1=dot[fax].d2,tmp.d2=dot[fay].d2; if(len==s4) tmp.d1=dot[fax].d2,tmp.d2=dot[fay].d1; if(len==s5) tmp=dot[fax];if(len==s6) tmp=dot[fay]; dot[fay]=tmp; print[i]=print[i+1]*len%mod; print[i]=(print[i]*inv(s5))%mod; print[i]=(print[i]*inv(s6))%mod; } for(int i=1;i<=n;i++) printf("%d\n",print[i]); return 0;}