Skip to content

Commit

Permalink
Merge pull request #1 from masaid24/master
Browse files Browse the repository at this point in the history
added usage comments
  • Loading branch information
hammamikhairi authored Feb 10, 2023
2 parents ee5c893 + a64d208 commit 7010a88
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions Tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@ import (

type Counter map[string]int

// The Decision Tree
type DecisionTree struct {
root *node
}

// Initialize the Decision Tree
//
// Parameters
// ----------
// Y : []int : The target variable to predict
// X : dataframe.DataFrame <float64 & float64> : The features
// maxTreeDepth : int : The maximum depth of the tree
// minDfSplit : int : The minimum number of samples to split a node
func TreeInit(Y []int, X dataframe.DataFrame, maxTreeDepth, minDfSplit int) *DecisionTree {
tree := &DecisionTree{
root: NodeInit(Y, X, 0, maxTreeDepth, minDfSplit, "ROOT"),
Expand Down Expand Up @@ -94,10 +103,16 @@ func meth(col []float64) []float64 {
return methed
}

// Generate tree
func (tree *DecisionTree) Sprout() {
tree.root.sprout()
}

// Predict the target variable for a given set of features
//
// Parameters
// ----------
// X : dataframe.DataFrame : The features
func (tree *DecisionTree) Predict(data dataframe.DataFrame) []string {

features := tree.root.data.features
Expand All @@ -115,6 +130,7 @@ func (tree *DecisionTree) Predict(data dataframe.DataFrame) []string {
return predictions
}

// Print the tree in a human readable format
func (tree *DecisionTree) Print() {
tree.root.print(1)
}

0 comments on commit 7010a88

Please sign in to comment.