Machine Learning/grid.py

From Noisebridge
< Machine Learning
Revision as of 15:11, 8 August 2010 by SpammerHellDontDelete (Talk | contribs)

(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to: navigation, search
#!/usr/bin/env python



import os, sys, traceback
import getpass
from threading import Thread
from subprocess import *

if sys.hexversion < 0x03000000:
	import Queue
else:
	import queue as Queue


# Default locations of the svm-train and gnuplot executables.
# Both can be overridden on the command line (-svmtrain / -gnuplot).

if sys.platform == 'win32':
    # example for windows
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Global parameters and their default values.

fold = 5                # value handed to svm-train's -v option

# (begin, end, step) of the log2(C) grid
c_begin = -5
c_end = 15
c_step = 2

# (begin, end, step) of the log2(gamma) grid
g_begin = 3
g_end = -15
g_step = -2

# dataset_pathname, dataset_title, pass_through_string, out_filename and
# png_filename are further module globals; they have no default here and
# are assigned by process_options().

# Experimental: remote hosts to use as workers, and the number of
# worker threads that run svm-train on the local machine.

telnet_workers = []
ssh_workers = []
nr_local_worker = 1

# process command line options, set global parameters
def process_options(argv=sys.argv):
    """Parse the command line and store the settings in module globals.

    The last element of ``argv`` is always taken as the dataset path.
    Everything between the program name and the dataset is scanned for the
    options listed in the usage text; anything unrecognised is collected
    and later handed to svm-train unchanged (``pass_through_string``).

    Globals written: ``fold``, ``c_begin/c_end/c_step``,
    ``g_begin/g_end/g_step``, ``dataset_pathname``, ``dataset_title``,
    ``pass_through_string``, ``svmtrain_exe``, ``gnuplot_exe``,
    ``out_filename`` and ``png_filename``.

    Prints the usage text and exits with status 1 when no dataset is
    given, when the renamed ``-c``/``-g`` options are used, or when an
    option that needs a value is the last thing before the dataset.

    Raises:
        AssertionError: if the svm-train executable or the dataset file
            does not exist.
        ValueError: if a ``-log2c``/``-log2g`` value is not three
            comma-separated numbers.
    """

    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    global svmtrain_exe, gnuplot_exe, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '%s.out' % dataset_title
    png_filename = '%s.png' % dataset_title
    pass_through_options = []

    # Options that consume the following argument as their value.
    value_options = ('-log2c', '-log2g', '-v',
                     '-svmtrain', '-gnuplot', '-out', '-png')

    dataset_index = len(argv) - 1
    i = 1
    while i < dataset_index:
        option = argv[i]
        if option in ('-c', '-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)

        if option in value_options:
            i += 1
            if i >= dataset_index:
                # Without this check the dataset path itself would be
                # silently swallowed as the option's value.
                print("Option %s requires a value." % option)
                print(usage)
                sys.exit(1)
            value = argv[i]
            if option == '-log2c':
                c_begin, c_end, c_step = map(float, value.split(","))
            elif option == '-log2g':
                g_begin, g_end, g_step = map(float, value.split(","))
            elif option == '-v':
                # kept as a string: it is only interpolated into the
                # svm-train command line
                fold = value
            elif option == '-svmtrain':
                svmtrain_exe = value
            elif option == '-gnuplot':
                gnuplot_exe = value
            elif option == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        else:
            pass_through_options.append(option)
        i += 1

    pass_through_string = " ".join(pass_through_options)

    # Raised explicitly (rather than via ``assert``) so the checks still
    # run under ``python -O``; the exception type is unchanged.
    if not os.path.exists(svmtrain_exe):
        raise AssertionError("svm-train executable not found")
    if not os.path.exists(dataset_pathname):
        raise AssertionError("dataset not found")


def range_f(begin,end,step):
    """Return the arithmetic sequence from ``begin`` to ``end`` inclusive.

    Like ``range`` but also works for non-integer values, and ``end``
    itself is included when the sequence lands on it.  A positive ``step``
    counts upwards, a negative one counts downwards; if ``begin`` already
    lies past ``end`` in the direction of travel the result is empty.

    Raises:
        ValueError: if ``step`` is zero (the sequence would never end).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")

    ascending = step > 0
    seq = []
    value = begin
    while (value <= end) if ascending else (value >= end):
        seq.append(value)
        value += step
    return seq

def permute_sequence(seq):
    """Return the items of ``seq`` reordered "middle first".

    The middle element comes first, followed by the recursively permuted
    left and right halves interleaved one item at a time (left, right,
    left, ...).  Presumably this lets a grid search visit a coarse spread
    of values before the ones in between.

    Accepts any sliceable sequence (list, tuple, ...) and always returns
    a new list; the input is not modified.
    """
    n = len(seq)
    if n <= 1:
        return list(seq)

    mid = n // 2
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid + 1:])

    ret = [seq[mid]]
    # Interleave by index rather than pop(0), which is O(n) per call.
    for k in range(max(len(left), len(right))):
        if k < len(left):
            ret.append(left[k])
        if k < len(right):
            ret.append(right[k])

    return ret

def redraw(db,best_param,tofile=False):
    """Prepare the collected results for plotting.

    ``db`` is a list of ``(log2c, log2g, rate)`` rows and ``best_param``
    a ``(log2c, log2g, rate)`` triple.  No plot is produced by the code
    as it stands (NOTE(review): the drawing statements appear to have
    been removed); the only visible effect is that ``db`` is sorted in
    place by ascending first column, then descending second column.
    ``tofile`` is accepted but not used.  Does nothing for an empty ``db``.
    """
    if not db:
        return

    # contour settings; computed but not consumed by anything below
    top_rate = max(row[2] for row in db)
    begin_level = round(top_rate) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    def sort_key(row):
        return (row[0], -row[1])

    db.sort(key=sort_key)

    # walk the rows, tracking each change of the first column
    current_c = db[0][0]
    for row in db:
        if current_c != row[0]:
            current_c = row[0]

def calculate_jobs():
    """Build the list of job batches covering the (log2c, log2g) grid.

    Both axes are generated from the module-level begin/end/step settings
    and reordered with permute_sequence().  The grid is then filled in
    alternately: each step adds either the next C value paired with every
    gamma value used so far, or the next gamma value paired with every C
    value used so far, whichever axis is proportionally behind.

    Returns a list of batches, each a list of ``(log2c, log2g)`` tuples.
    """
    c_values = permute_sequence(range_f(c_begin, c_end, c_step))
    g_values = permute_sequence(range_f(g_begin, g_end, g_step))
    total_c = float(len(c_values))
    total_g = float(len(g_values))

    ci = 0
    gi = 0
    batches = []

    while ci < total_c or gi < total_g:
        c_is_behind = ci / total_c < gi / total_g
        if c_is_behind:
            # increase C resolution
            batch = [(c_values[ci], g) for g in g_values[:gi]]
            ci += 1
        else:
            # increase g resolution
            batch = [(c, g_values[gi]) for c in c_values[:ci]]
            gi += 1
        batches.append(batch)

    return batches

class WorkerStopToken:
    """Marker put on the job queue to notify the workers to stop."""

class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and evaluates them.

    Subclasses provide ``run_one(c, g)``, which receives the actual
    parameter values (``2**log2c``, ``2**log2g``) and returns the
    cross-validation rate, or ``None`` if none could be obtained.
    """

    def __init__(self,name,job_queue,result_queue):
        """Remember the worker's name and the queues it reads and writes."""
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        Each successful job puts ``(name, log2c, log2g, rate)`` on the
        result queue.  On failure the traceback is printed, the job is
        put back on the job queue for another worker, and this worker
        quits.
        """
        while True:
            cexp,gexp = self.job_queue.get()
            if cexp is WorkerStopToken:
                # put the token back so the other workers see it too
                self.job_queue.put((cexp,gexp))
                break
            try:
                rate = self.run_one(2.0**cexp,2.0**gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exception(*sys.exc_info())

                self.job_queue.put((cexp,gexp))
                print('worker %s quit.' % self.name)
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))

class LocalWorker(Worker):
    """Worker that runs svm-train as a child process on this machine."""

    def run_one(self,c,g):
        """Cross-validate with parameters ``c`` and ``g``; return the rate.

        Runs svm-train through the shell, then scans its output for the
        first line containing "Cross" and returns that line's last token,
        minus its final character, as a float.  Returns ``None`` if no
        such line is printed.
        """
        # NOTE: the command is assembled as a shell string, so paths or
        # options containing spaces or shell metacharacters are not quoted.
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        process = Popen(cmdline,shell=True,stdout=PIPE)
        # communicate() drains the pipe, closes it and reaps the child,
        # so neither a file descriptor nor a zombie process is left behind.
        output = process.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1].decode("ascii"))

class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ``ssh``."""

    def __init__(self,name,job_queue,result_queue,host):
        """Remember the remote ``host`` and the current working directory.

        The working directory recorded here is the one the remote command
        changes into, so the remote host is expected to have the same
        directory layout.
        """
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self,c,g):
        """Cross-validate remotely with ``c`` and ``g``; return the rate.

        Scans the command's output for the first line containing "Cross"
        and returns that line's last token, minus its final character, as
        a float.  Returns ``None`` if no such line is printed.
        """
        # NOTE: the command is assembled as a shell string, so paths or
        # options containing spaces or shell metacharacters are not quoted.
        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \
          (self.host,self.cwd,
           svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        process = Popen(cmdline,shell=True,stdout=PIPE)
        # communicate() drains the pipe, closes it and reaps the child,
        # so neither a file descriptor nor a zombie process is left behind.
        output = process.communicate()[0]
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1].decode("ascii"))

class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session.

    Relies on the standard-library ``telnetlib`` module, which is imported
    lazily in run() and is not available on Python 3.13 and later.
    """

    def __init__(self,name,job_queue,result_queue,host,username,password):
        """Remember the remote ``host`` and the login credentials."""
        Worker.__init__(self,name,job_queue,result_queue)
        self.host = host
        self.username = username
        self.password = password

    @staticmethod
    def _to_bytes(text):
        """Return ``text`` as bytes; telnetlib reads and writes bytes only."""
        if isinstance(text, bytes):
            return text
        return text.encode("utf-8")

    def run(self):
        """Log in, change to the current directory, then process jobs.

        The connection is closed when the worker finishes, whether it
        stopped normally or because of an error.
        """
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        try:
            tn.read_until(b"login: ")
            tn.write(self._to_bytes(self.username + "\n"))
            tn.read_until(b"Password: ")
            tn.write(self._to_bytes(self.password + "\n"))

            # XXX: how to know whether login is successful?
            # For now, wait until the user name is echoed back.
            tn.read_until(self._to_bytes(self.username))
            print('login ok', self.host)
            tn.write(self._to_bytes("cd "+os.getcwd()+"\n"))
            Worker.run(self)
            tn.write(b"exit\n")
        finally:
            tn.close()

    def run_one(self,c,g):
        """Cross-validate remotely with ``c`` and ``g``; return the rate.

        Sends the svm-train command over the session and waits (without a
        timeout) for a line matching ``Cross.*``.  Returns that line's
        last token, minus its final character, as a float.
        """
        cmdline = '%s -c %s -g %s -v %s %s %s' % \
          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)
        self.tn.write(self._to_bytes(cmdline+'\n'))
        idx,matchm,output = self.tn.expect([b'Cross.*\n'])
        for line in output.splitlines():
            if b"Cross" in line:
                return float(line.split()[-1][0:-1].decode("ascii"))

def main():
    """Run the grid search described by the command line.

    Parses the options, queues every (log2c, log2g) job, starts the
    telnet, ssh and local workers, then collects results in job order.
    Each result is appended to the output file and echoed to stdout
    together with the best parameters seen so far.  Finally the workers
    are told to stop and the best ``c``, ``g`` and rate are printed.
    """

    # set parameters

    process_options()

    # put jobs in queue

    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

    # From now on put() inserts at the head of the queue instead of the
    # tail, so anything queued later (a job handed back by a failed
    # worker, the stop token) is picked up before the remaining jobs.
    # NOTE: this overrides an internal hook of Queue.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                     host,username,password).start()

    # fire ssh workers

    for host in ssh_workers:
        SSHWorker(host,job_queue,result_queue,host).start()

    # fire local workers

    for i in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()

    # gather results

    done_jobs = {}
    db = []
    best_rate = -1
    best_c1,best_g1 = None,None
    # defined up front so the final print works even with no jobs
    best_c,best_g = None,None

    try:
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c,g) in line:
                    # results arrive in any order; wait for this job's
                    while (c, g) not in done_jobs:
                        (worker,c1,g1,rate) = result_queue.get()
                        done_jobs[(c1,g1)] = rate
                        result_file.write('%s %s %s\n' %(c1,g1,rate))
                        result_file.flush()
                        # on a tie with the same g, prefer the smaller c
                        if rate > best_rate or (rate==best_rate and g1==best_g1 and c1<best_c1):
                            best_rate = rate
                            best_c1,best_g1=c1,g1
                            best_c = 2.0**c1
                            best_g = 2.0**g1
                        print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \
                            (worker,c1,g1,rate, best_c, best_g, best_rate))
                    db.append((c,g,done_jobs[(c,g)]))
                redraw(db,[best_c1, best_g1, best_rate])
                redraw(db,[best_c1, best_g1, best_rate],True)
    finally:
        # always release the workers, otherwise their threads keep the
        # process alive after an error
        job_queue.put((WorkerStopToken,None))

    print("%s %s %s" % (best_c, best_g, best_rate))
main()
Personal tools