Machine Learning Ep.2 : Cross Validation

stackpython
2 min readJan 18, 2020

--

สำหรับตอนแรกเราได้รู้จัก Machine Learning แบบคร่าวๆกันไปแล้วนะครับ และได้เห็นตัวอย่างโมเดลแบบง่ายๆ ในตอนที่ 2 นี้ผมจะมาเล่าให้ฟังเรื่อง Cross Validation ว่ามันคืออะไรและสำคัญยังไง ฟังชื่อตอนแรกอาจจะดูน่ากลัวและซับซ้อนแต่หลังจากอ่านบทความนี้ รับประกันได้เลยครับว่าไม่ได้ยากอย่างที่คิด!

เราแบ่งข้อมูลไว้เป็น 2 ส่วนคือ Training Data (สีเขียว) กับ Testing Data (สีดำ)

พอมาทดสอบปรากฏว่า Regression line ทำนายได้แย่จาก Training Data ชุดนี้

คำถาม : แล้วเราจะรู้ได้อย่างไรว่าข้อมูลไหนควรนำไปเป็น Training Data?

ตอบ: Cross Validation ไงหละ!!

Cross Validation คือเครื่องมือที่ช่วยให้เราตัดสินใจได้ว่าเราควรแบ่งข้อมูลส่วนไหนไปเป็น Training Data ด้วยวิธีการหลากหลายต่างๆนาๆครับ

โดยบทความนี้ผมขอนำเสนอวิธีที่เรียกว่า K-Fold Cross Validation โดยอ้างอิงจากตัวอย่างด้านบนครับ

K-Fold Cross Validation คือการที่เราแบ่งข้อมูลเป็นจำนวน K ส่วนโดยการในแต่ละส่วนจะต้องมาจากสุ่มเพื่อที่จะให้ข้อมูลของเรากระจายเท่าๆกัน ยกตัวอย่างเช่น ข้อมูลที่เรียงจากน้อยไปมาก [(1.5,2.2),(2,3),(2.5,2.8),…,(5.6,5.2),(6.1,4.8),(6.8,6)] จะเห็นได้ว่าถ้าเราไม่สุ่ม คู่อันดับที่อยู่ล่างๆ กับที่อยู่บนๆ ก็จะไม่กระจายกัน

4-Fold Cross Validation

หลังจากที่เราแบ่งข้อมูลเป็นส่วนๆเรียบร้อย เราก็จะนำ Training Data ของเราไปหา Regression Line และทดสอบครับว่าข้อมูลชุดไหนที่ทำให้เส้นของเราทำนายได้ดีที่สุดนั่นเอง!!

ภาพจาก stackoverflow.com/questions/53779773/python-linear-regression-best-fit-line-with-residuals

เพิ่มเติม : เราต้องคำนวณ Sum of squared residuals (SSR) ของ Testing Data ในข้อมูลแต่ละชุดเพื่อใช้ในการเปรียบเทียบ ซึ่งเราจะเลือกชุดข้อมูลที่ SSR น้อยที่สุด (SSR คือระยะทางทั้งหมดจากจุดไปเส้น SSR น้อย = Error น้อย = เส้นทำนายได้ดี)

ภาพจาก Youtube : StatQuest with Josh Starmer

นอกจากการหา Training Data ที่ดีที่สุดแล้ว Cross Validation ยังสามารถใช้เปรียบเทียบได้อีกว่าเราควรใช้ วิธีไหนที่เหมาะสมที่สุดในการสร้างโมเดลของเรา จากตัวอย่างด้านบนจะเห็นได้ว่าใน Testing Data วิธีการทำนายที่แม่นยำที่สุดคือ Support Vector machines (SVM) ด้วยความแม่นยำ (18/24 = 75%)

คำถาม : แล้วถ้าเป็นข้อมูลแบบตัวอย่างด้านบนหละจะมีวิธีการเปรียบเทียบยังไง?

ตอบ : ดูค่า SSR แล้วเลือกวิธีที่ SSR น้อยที่สุด

สรุป

  • Cross Validation ใช้เปรียบเทียบว่าข้อมูลชุดไหนดีที่สุดในโมเดล
  • Cross Validation ใช้เปรียบเทียบระหว่างโมเดลได้ว่าโมเดลไหนดีกว่ากัน

หากเพื่อนๆชอบบทความนี้ก็สามารถให้กำลังใจผู้เขียนได้ด้วยกันกด Clap หรือถ้ามีส่วนไหนอยากแนะนำเพิ่มเติมก็สามารถ Response กันเข้ามาเพื่อนำไปพัฒนาบทความถัดๆไป ยินดีรับฟังทุกความเห็นครับ ^^

--

--

stackpython
stackpython

Responses (1)