Drkcore

11 08 2016 Python tensorflow scikit-learn Tweet

DNNをRandom Forest (RF)やSupport Vector Machine (SVM)と比較したい

TensorFlowのDNNチュートリアルだとトレーニングセットとテストセットをファイルから読みだすので、実用的にはちょっと面倒くさい。scikit-learnのよろしく分割してくれるメソッド使ったほうが楽でしょう。

またこScikit-learnとTensorFlowを組み合わせることでそれぞれのアルゴリズムの精度を比較することが簡単にできるので便利。

import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn import cross_validation

iris = datasets.load_iris()
x_train, x_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, test_size=0.4, random_state=0)

classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
classifier.fit(x=x_train, y=y_train, steps=200)
dnn_accuracy_score = classifier.evaluate(x=x_test, y=y_test)["accuracy"]
print('DNN Accuracy: {0:f}'.format(dnn_accuracy_score))

clf = svm.SVC(kernel='linear').fit(x_train, y_train)
svm_accuracy_score = clf.score(x_test, y_test)
print('SVM Accuracy: {0:f}'.format(svm_accuracy_score))

rlf = RandomForestClassifier().fit(x_train, y_train)
rf_accuracy_score = rlf.score(x_test, y_test)
print('RF Accuracy: {0:f}'.format(rf_accuracy_score))

About

  • もう5年目(wishlistありマス♡)
  • 最近はPythonとDeepLearning
  • 日本酒自粛中
  • ドラムンベースからミニマルまで
  • ポケモンGOゆるめ

Tag

Python Deep Learning javascript chemoinformatics Emacs sake and more...

Ad

© kzfm 2003-2021