########################################################################
## Copyright (C) 2006 by Marek Wojciechowski
## <mwojc@p.lodz.pl>
##
## Distributed under the terms of the GNU General Public License (GPL)
## http://www.gnu.org/copyleft/gpl.html
########################################################################
### Digits recognition example for ffnet ###
# The training data set is contained in the file data/ocr.dat.
# This file contains 68 patterns - first 58 are used for training
# and last 10 are used for testing. Each pattern contains 64 inputs
# which define 8x8 bitmap of the digit and the last 10 numbers are
# the targets (10 targets for 10 digits).
# Layered network architecture is used here: (64, 10, 10, 10)
from ffnet import ffnet, mlgraph, readdata
# Generate standard layered network architecture and create network.
# Layer sizes: 64 inputs (8x8 bitmap), two hidden layers of 10 units,
# 10 outputs (one per digit).
conec = mlgraph((64, 10, 10, 10))
net = ffnet(conec)

# Read data file.
# NOTE: print is called with a single parenthesised argument throughout,
# which gives the same output under Python 2 and Python 3.
print("READING DATA...")
data = readdata('data/ocr.dat', separator=' ')
n_inputs = 64   # columns holding the bitmap definition
n_train = 58    # leading rows used for training; the rest are for testing
input = data[:, :n_inputs]    # first 64 columns - bitmap definition
target = data[:, n_inputs:]   # the rest - 10 columns for 10 digits

# Train network with scipy tnc optimizer - 58 lines used for training
print("TRAINING NETWORK...")
net.train_tnc(input[:n_train], target[:n_train], maxfun=2000, messages=1)

# Test network - remaining 10 lines used for testing
print("")   # blank separator line before the test report
print("TESTING NETWORK...")
output, regression = net.test(input[n_train:], target[n_train:], iprint=2)
############################################################
# Make a plot of a chosen digit along with the network guess.
# Plotting is optional: when an ImportError is raised anywhere in this
# section (typically because matplotlib is not installed) a hint is
# printed and the script carries on.
try:
    # Explicit names instead of a star import, so it is clear which
    # plotting helpers this section relies on.
    from pylab import (arange, axhline, bar, imshow, show, subplot,
                       title, xlabel, xlim, xticks, ylabel)
    from random import randint

    # Choose testing pattern to plot. randint includes both end points,
    # so this picks one of the rows 58..67 that were held out for testing.
    digitpat = randint(58, 67)

    # Upper panel: the 8x8 bitmap of the chosen digit
    subplot(211)
    imshow(input[digitpat].reshape(8, 8), interpolation='nearest')

    # Lower panel: one bar per network output
    subplot(212)
    N = 10           # number of digits / network outputs
    ind = arange(N)  # the x locations for the groups
    width = 0.35     # the width of the bars
    # align='edge' pins the left edge of each bar at ind, which is what
    # the tick positions (ind + width/2) and the x limits below assume.
    bar(ind, net(input[digitpat]), width, color='b', align='edge')
    xticks(ind + width / 2.,
           ('1', '2', '3', '4', '5', '6', '7', '8', '9', '0'))
    xlim(-width, N - width)
    axhline(linewidth=1, color='black')
    title("Trained network (64-10-10-10) guesses a digit above...")
    xlabel("Digit")
    ylabel("Network outputs")
    show()
except ImportError:
    print("Cannot make plots. For plotting install matplotlib...")
# Closing remark for the user. The text is bound to a name first so the
# print call is a plain one-argument call, which produces the same output
# under Python 2 and Python 3.
note = """
Note:
Normalization of input/output data is handled automatically in ffnet.
Just use your raw data both at trainig and recalling phase.
"""
print(note)