I'm an sklearn dummy. I'm trying to predict the label for a given string with a RandomForestClassifier() fitted on texts and their labels.

It's obvious I don't know how to use predict() with a single string. The reason I'm using reshape() is that I got this error some time ago: "Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample."

How can I predict the label of a single text string?

The script:

    #!/usr/bin/env python
    '''
    Read a txt file consisting of '<label>: <long string of text>'
    to use as a model for predicting the label for a string
    '''

    from argparse import ArgumentParser
    import json
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import LabelEncoder


    def main(args):
        '''
        args: Arguments obtained by _Get_Args()
        '''

        print('Loading data...')

        # Load data from args.txtfile and split the lines into
        # two lists (labels, texts).
        data = open(args.txtfile).readlines()

        labels, texts = ([], [])
        for line in data:
            label, text = line.split(': ', 1)
            labels.append(label)
            texts.append(text)

        # Print a list of unique labels
        print(json.dumps(list(set(labels)), indent=4))

        # Instantiate a CountVectorizer class and fit the texts
        # into it.
        cv = CountVectorizer(
                stop_words='english',
                strip_accents='unicode',
                lowercase=True,
                )
        matrix = cv.fit_transform(texts)

        # Encode the string labels as integers and fit the classifier.
        encoder = LabelEncoder()
        labels = encoder.fit_transform(labels)

        rf = RandomForestClassifier()
        rf.fit(matrix, labels)

        # Try to predict the label for args.string.
        prediction = Predict_Label(args.string, cv, rf)
        print(prediction)


    def Predict_Label(string, cv, rf):
        '''
        string: str() - A string of text
        cv: The CountVectorizer class
        rf: The RandomForestClassifier class
        '''

        matrix = cv.fit_transform([string])
        matrix = matrix.reshape(1, -1)

        try:
            prediction = rf.predict(matrix)
        except Exception as E:
            print(str(E))
        else:
            return prediction


    def _Get_Args():
        parser = ArgumentParser(description='Learn labels from text')
        parser.add_argument('-t', '--txtfile', required=True)
        parser.add_argument('-s', '--string', required=True)
        return parser.parse_args()


    if __name__ == '__main__':
        args = _Get_Args()
        main(args)

The actual training data file is 43663 lines long, but a sample is in small_list.txt, where each line has the format: <label>: <long text string>
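To make the single-string case easier to reproduce without the full file, here is a minimal sketch that uses a few made-up lines in the same `<label>: <text>` format. The toy data, the lowercase `predict_label` helper, and the `random_state=0` setting are assumptions for illustration only and are not part of the script above. The sketch also assumes that the vectorizer fitted on the training texts should be reused through `transform()` for the new string, so that the new row has the same number of columns the forest was trained on; in that case `transform([string])` already returns a single row and no `reshape()` is involved.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder

# Hypothetical stand-in for the real file, same '<label>: <text>' format.
lines = [
    "sports: the team won the football match last night",
    "sports: the striker scored a goal in the football game",
    "cooking: simmer the tomato sauce and add fresh basil",
    "cooking: bake the bread dough in a hot oven",
]

# Split each line into its label and its text.
labels, texts = [], []
for line in lines:
    label, text = line.split(": ", 1)
    labels.append(label)
    texts.append(text)

# Learn the vocabulary from the training texts only.
cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
train_matrix = cv.fit_transform(texts)

# Encode the string labels as integers.
encoder = LabelEncoder()
y = encoder.fit_transform(labels)

# random_state is fixed only to make the sketch repeatable.
rf = RandomForestClassifier(random_state=0)
rf.fit(train_matrix, y)


def predict_label(string, cv, rf, encoder):
    """Return the decoded label predicted for one text string."""
    # Reuse the fitted vocabulary: one row, same columns as the training matrix.
    matrix = cv.transform([string])
    encoded = rf.predict(matrix)
    # Map the integer class back to the original label string.
    return encoder.inverse_transform(encoded)[0]


prediction = predict_label("bake the bread in the oven", cv, rf, encoder)
print(prediction)
```

The helper takes the `LabelEncoder` as an extra argument so the caller gets the original label text back rather than its integer code.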

The error is noted in the Exception output: