Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.

For instance, in the example below, a decision tree learns from data to predict a person's preferred music genre from their age and sex. The deeper the tree, the more complex the decision rules and the more closely the model fits the training data, which also raises the risk of overfitting.

DecisionTreeClassifier is a class capable of performing multi-class classification on a dataset.

As with other classifiers, DecisionTreeClassifier takes two arrays as input: an array X, sparse or dense, of size [n_samples, n_features] holding the training samples, and an array Y of size [n_samples] holding the class labels (integers or strings) for the training samples:



    from sklearn import tree
    import sys

    # Each sample is [age, sex], with sex encoded as 0 = male, 1 = female.
    features = [
        [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
        [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1],
    ]

    # Preferred music genre for each sample, in the same order as `features`.
    labels = [
        'rap', 'rap', 'hip hop', 'hip hop', 'rock', 'rock', 'rock', 'country', 'country',
        'dance', 'dance', 'hip hop', 'hip hop', 'rap', 'rap', 'rap', 'classical', 'classical',
    ]

    clf = tree.DecisionTreeClassifier()
    clf.fit(features, labels)

    # Pass age and sex as script parameters; sys.argv holds strings, so convert them to int.
    age, sex = int(sys.argv[1]), int(sys.argv[2])
    prediction = clf.predict([[age, sex]])
    print(prediction)

Try it!



    python3.7 decision_tree_classifier.py 18 1
    ['dance']
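The earlier remark about tree depth can be checked on the same toy data. The following is a minimal sketch, not part of the original script: the helper name `depth_and_training_accuracy` is hypothetical, and `random_state=0` is an assumption added only to make runs repeatable. It compares a tree limited to a single split with a tree grown without a depth limit.

```python
from sklearn.tree import DecisionTreeClassifier

# Same toy data as the script above: [age, sex], with 0 = male, 1 = female.
ages = (18, 19, 22, 25, 28, 31, 34, 40, 45)
features = [[age, sex] for sex in (0, 1) for age in ages]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop', 'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop', 'rap', 'rap', 'rap', 'classical', 'classical',
]


def depth_and_training_accuracy(max_depth):
    """Fit a tree with the given depth limit; return (actual depth, training accuracy)."""
    clf = DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    clf.fit(features, labels)
    return clf.get_depth(), clf.score(features, labels)


# A single split cannot separate six genres; an unrestricted tree can memorise them.
shallow_depth, shallow_acc = depth_and_training_accuracy(1)
full_depth, full_acc = depth_and_training_accuracy(None)

print("max_depth=1    -> depth", shallow_depth, "training accuracy", shallow_acc)
print("max_depth=None -> depth", full_depth, "training accuracy", full_acc)
```

Note that both scores are measured on the training samples themselves, so the perfect score of the unrestricted tree says nothing about how it would do on unseen people; that would need a separate held-out set.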

The tree can also be exported in text format with the function export_text.



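    # export_text is imported from sklearn.tree; the sklearn.tree.export
    # module path was deprecated and is no longer available in recent releases.
    from sklearn.tree import export_text

    decision_tree_text = export_text(clf, feature_names=['age', 'sex'])
    print(decision_tree_text)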
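Because the snippet above depends on the `clf` object from the earlier script, here is a self-contained sketch of the same export. It also tries two optional `export_text` parameters, `max_depth` and `decimals`, to shorten the report; the variable names and `random_state=0` are assumptions made for illustration.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Same toy data as the classifier script: [age, sex], with 0 = male, 1 = female.
ages = (18, 19, 22, 25, 28, 31, 34, 40, 45)
features = [[age, sex] for sex in (0, 1) for age in ages]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop', 'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop', 'rap', 'rap', 'rap', 'classical', 'classical',
]

clf = DecisionTreeClassifier(random_state=0).fit(features, labels)

# Full set of rules, one line per split or leaf.
full_rules = export_text(clf, feature_names=['age', 'sex'])

# Shorter report: deeper branches are summarised instead of printed in full,
# and thresholds are shown with one decimal place.
top_rules = export_text(clf, feature_names=['age', 'sex'], max_depth=1, decimals=1)

print(full_rules)
print(top_rules)
```

Each line of the report is either a threshold test on a feature or a leaf with its predicted class, so the full version can be read as nested if/else rules.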