คำอธิบายวิธีทำพร้อม code สำหรับข้อ toi21_quartet

สรุปโจทย์

ในเมืองนครแห่งหนึ่ง มีสถานที่อยู่ 33 ที่ ได้แก่

สิ่งที่ต้องทำ

นับจำนวน “กลุ่มอพยพปลอดภัย” โดยให้ระบุคำตอบในรูปแบบของเศษจากการหารจำนวนของ “กลุ่มอพยพปลอดภัย” ด้วย 109+710^9 + 7

ขอบเขตข้อมูล

44 \leq NN \leq 10510^5
11 \leq MM \leq N2N - 2

ข้อสังเกต

จำนวนถนน จะเท่ากับ N+M1N + M - 1 ซึ่งหมายความว่า Graph ที่ได้มาจะมีลักษณะเป็น Tree

ไอเดียหลัก

เราจะเก็บว่า สำหรับแต่ละศูนย์ตำบล uu ถ้าเราเลือกให้หมู่บ้าน 44 แห่งเดินทางมายังศูนย์ตำบล uu โดยไม่ใช้ถนนสายเดียวกันเลย เราจะนับได้กี่กลุ่ม (นับจำนวนกลุ่มอพยพปลอดภัยที่สามารถเดินทางมา uu ได้)

Prerequisites

วิธีทำ

เราจะทำ Dynamic Programming on Tree บน Tree ที่ได้รับมา โดยเราจะเก็บใน array of vector (ตั้งชื่อว่า cntcnt) ว่า ถ้าเราอยู่ที่โหนด uu แล้ว หากเราเดินตามถนนแต่ละสายที่ติดกับโหนด uu (เดินตาม adjacency list) แล้ว จะมีหมู่บ้านจำนวนกี่หมู่บ้านในเส้นทางนั้น (นับจำนวนหมู่บ้านที่อยู่ใน subtree ของ adj[u][i]adj[u][i] แล้วเก็บไว้ใน cnt[u][i]cnt[u][i]) นั่นคือ ขนาดของ cnt[u]cnt[u] สำหรับแต่ละ uu อาจจะมีขนาดไม่เท่ากันนั่นเอง
เมื่อเราได้ cntcnt มาแล้ว ทีนี้ เราจะทำการ loop จาก N+1N + 1 ถึง N+MN + M (loop แค่ศูนย์ตำบล) แล้วจะคำนวณว่า ถ้าหากว่าเราจะนับจากศูนย์ตำบลที่ uu แล้ว จะมี “กลุ่มอพยพปลอดภัย” ทั้งหมดกี่กลุ่ม และเมื่อเราคำนวณครบสำหรับศูนย์ตำบลทั้ง MM ศูนย์แล้ว เราจะได้คำตอบออกมา
วิธีการคำนวณ:

Summary

Solution Code:

#include <bits/stdc++.h> 

using namespace std; 

#define int long long 

const int mod = 1e9 + 7; 
const int N = 1e5 + 5; 
const int M = 1e5 + 5; 

int n, m, ans, sz[N + M]; 
vector <int> adj[N + M], cnt[N + M]; 

void dfs(int u, int p){ 
	sz[u] = (u <= n); 
	for (auto v : adj[u]) { 
		if (v == p) continue; 
		dfs(v, u); 
		if (sz[v] > 0) cnt[u].emplace_back(sz[v]), sz[u] += sz[v]; 
	} 
	cnt[u].emplace_back(n - sz[u]); 
} 

int32_t main(){ 
	cin.tie(NULL)->sync_with_stdio(false); 
	cin >> n >> m; 
	for (int i = 1; i < n + m; i++) { 
		int u, v; cin >> u >> v; 
		adj[u].emplace_back(v); 
		adj[v].emplace_back(u); 
	} 
	dfs(1, 1); 
	for (int u = n + 1; u <= n + m; u++) { 
		int siz = adj[u].size(); 
		if (siz >= 4) { 
			int dp[siz + 1][5]; 
			memset(dp, 0, sizeof dp); 
			dp[0][0] = 1; 
			for (int i = 1; i <= siz; i++) { 
				dp[i][0] = 1; 
				for (int j = 1; j <= 4; j++) { 
					dp[i][j] = dp[i - 1][j] + (dp[i - 1][j - 1] * cnt[u][i - 1]); 
				} 
			} 
			ans += dp[siz][4]; ans %= mod; 
		} 
	} 
	cout << ans; 
}

Total Time Complexity: O(n+m)O(n + m)

หากมีข้อสงสัย comment ไว้ใต้ post ได้เลยนะครับ 🙇‍♂️🙇‍♂️
ศึกษาโจทย์เพิ่มเติมได้ที่ Fast X Fourier