Machine Learning/grid.py: Difference between revisions
Jump to navigation
Jump to search
ThomasLotze (talk | contribs) (Created page with '<pre> #!/usr/bin/env python import os, sys, traceback import getpass from threading import Thread from subprocess import * if(sys.hexversion < 0x03000000): import Queue else…') |
No edit summary |
||
Line 9: | Line 9: | ||
from subprocess import * | from subprocess import * | ||
if | if sys.hexversion < 0x03000000: | ||
import Queue | import Queue | ||
else: | else: | ||
Line 17: | Line 17: | ||
# svmtrain and gnuplot executable | # svmtrain and gnuplot executable | ||
if sys.platform != 'win32': | |||
svmtrain_exe = "../svm-train" | svmtrain_exe = "../svm-train" | ||
gnuplot_exe = "/usr/bin/gnuplot" | gnuplot_exe = "/usr/bin/gnuplot" | ||
Line 67: | Line 66: | ||
while i < len(argv) - 1: | while i < len(argv) - 1: | ||
if argv[i] == "-log2c": | if argv[i] == "-log2c": | ||
i = | i += 1 | ||
c_begin,c_end,c_step = map(float,argv[i].split(",")) | |||
elif argv[i] == "-log2g": | elif argv[i] == "-log2g": | ||
i = | i += 1 | ||
g_begin,g_end,g_step = map(float,argv[i].split(",")) | |||
elif argv[i] == "-v": | elif argv[i] == "-v": | ||
i = | i += 1 | ||
fold = argv[i] | fold = argv[i] | ||
elif argv[i] in ('-c','-g'): | elif argv[i] in ('-c','-g'): | ||
Line 80: | Line 79: | ||
sys.exit(1) | sys.exit(1) | ||
elif argv[i] == '-svmtrain': | elif argv[i] == '-svmtrain': | ||
i = | i += 1 | ||
svmtrain_exe = argv[i] | svmtrain_exe = argv[i] | ||
elif argv[i] == '-gnuplot': | elif argv[i] == '-gnuplot': | ||
i = | i += 1 | ||
gnuplot_exe = argv[i] | gnuplot_exe = argv[i] | ||
elif argv[i] == '-out': | elif argv[i] == '-out': | ||
i = | i += 1 | ||
out_filename = argv[i] | out_filename = argv[i] | ||
elif argv[i] == '-png': | elif argv[i] == '-png': | ||
i = | i += 1 | ||
png_filename = argv[i] | png_filename = argv[i] | ||
else: | else: | ||
pass_through_options.append(argv[i]) | pass_through_options.append(argv[i]) | ||
i = | i += 1 | ||
pass_through_string = " ".join(pass_through_options) | pass_through_string = " ".join(pass_through_options) | ||
Line 107: | Line 106: | ||
if step < 0 and begin < end: break | if step < 0 and begin < end: break | ||
seq.append(begin) | seq.append(begin) | ||
begin = | begin += step | ||
return seq | return seq | ||
Line 156: | Line 155: | ||
if i/nr_c < j/nr_g: | if i/nr_c < j/nr_g: | ||
# increase C resolution | # increase C resolution | ||
line = [ | line = [(c_seq[i],g_seq[k]) for k in range(j)] | ||
i += 1 | |||
i = | |||
jobs.append(line) | jobs.append(line) | ||
else: | else: | ||
# increase g resolution | # increase g resolution | ||
line = [ | line = [(c_seq[k],g_seq[j]) for k in range(i)] | ||
j += 1 | |||
j = | |||
jobs.append(line) | jobs.append(line) | ||
return jobs | return jobs | ||
Line 181: | Line 176: | ||
def run(self): | def run(self): | ||
while True: | while True: | ||
cexp,gexp = self.job_queue.get() | |||
if cexp is WorkerStopToken: | if cexp is WorkerStopToken: | ||
self.job_queue.put((cexp,gexp)) | self.job_queue.put((cexp,gexp)) | ||
Line 192: | Line 187: | ||
# we failed, let others do that and we just quit | # we failed, let others do that and we just quit | ||
traceback.print_exception(sys.exc_info() | traceback.print_exception(*sys.exc_info()) | ||
self.job_queue.put((cexp,gexp)) | self.job_queue.put((cexp,gexp)) | ||
Line 205: | Line 200: | ||
(svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | ||
result = Popen(cmdline,shell=True,stdout=PIPE).stdout | result = Popen(cmdline,shell=True,stdout=PIPE).stdout | ||
for line in result | for line in result: | ||
if str(line) | if "Cross" in str(line): | ||
return float(line.split()[-1][0:-1]) | return float(line.split()[-1][0:-1]) | ||
Line 219: | Line 214: | ||
svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | ||
result = Popen(cmdline,shell=True,stdout=PIPE).stdout | result = Popen(cmdline,shell=True,stdout=PIPE).stdout | ||
for line in result | for line in result: | ||
if str(line) | if "Cross" in str(line): | ||
return float(line.split()[-1][0:-1]) | return float(line.split()[-1][0:-1]) | ||
Line 248: | Line 243: | ||
(svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) | ||
result = self.tn.write(cmdline+'\n') | result = self.tn.write(cmdline+'\n') | ||
idx,matchm,output = self.tn.expect(['Cross.*\n']) | |||
for line in output. | for line in output.splitlines(): | ||
if str(line) | if "Cross" in str(line): | ||
return float(line.split()[-1][0:-1]) | return float(line.split()[-1][0:-1]) | ||
Line 312: | Line 307: | ||
result_file.write('%s %s %s\n' %(c1,g1,rate)) | result_file.write('%s %s %s\n' %(c1,g1,rate)) | ||
result_file.flush() | result_file.flush() | ||
if | if rate > best_rate or (rate==best_rate and g1==best_g1 and c1<best_c1): | ||
best_rate = rate | best_rate = rate | ||
best_c1,best_g1=c1,g1 | best_c1,best_g1=c1,g1 |
Latest revision as of 15:11, 8 August 2010
#!/usr/bin/env python import os, sys, traceback import getpass from threading import Thread from subprocess import * if sys.hexversion < 0x03000000: import Queue else: import queue as Queue # svmtrain and gnuplot executable if sys.platform != 'win32': svmtrain_exe = "../svm-train" gnuplot_exe = "/usr/bin/gnuplot" else: # example for windows svmtrain_exe = r"..\windows\svm-train.exe" gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe" # global parameters and their default values fold = 5 c_begin, c_end, c_step = -5, 15, 2 g_begin, g_end, g_step = 3, -15, -2 global dataset_pathname, dataset_title, pass_through_string global out_filename, png_filename # experimental telnet_workers = [] ssh_workers = [] nr_local_worker = 1 # process command line options, set global parameters def process_options(argv=sys.argv): global fold global c_begin, c_end, c_step global g_begin, g_end, g_step global dataset_pathname, dataset_title, pass_through_string global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename usage = """\ Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] [-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname] [additional parameters for svm-train] dataset""" if len(argv) < 2: print(usage) sys.exit(1) dataset_pathname = argv[-1] dataset_title = os.path.split(dataset_pathname)[1] out_filename = '%s.out' % dataset_title png_filename = '%s.png' % dataset_title pass_through_options = [] i = 1 while i < len(argv) - 1: if argv[i] == "-log2c": i += 1 c_begin,c_end,c_step = map(float,argv[i].split(",")) elif argv[i] == "-log2g": i += 1 g_begin,g_end,g_step = map(float,argv[i].split(",")) elif argv[i] == "-v": i += 1 fold = argv[i] elif argv[i] in ('-c','-g'): print("Option -c and -g are renamed.") print(usage) sys.exit(1) elif argv[i] == '-svmtrain': i += 1 svmtrain_exe = argv[i] elif argv[i] == '-gnuplot': i += 1 gnuplot_exe = argv[i] elif argv[i] == '-out': i += 1 out_filename = argv[i] elif argv[i] == '-png': i += 1 png_filename = argv[i] else: pass_through_options.append(argv[i]) i += 1 pass_through_string = " ".join(pass_through_options) assert os.path.exists(svmtrain_exe),"svm-train executable not found" assert os.path.exists(dataset_pathname),"dataset not found" def range_f(begin,end,step): # like range, but works on non-integer too seq = [] while True: if step > 0 and begin > end: break if step < 0 and begin < end: break seq.append(begin) begin += step return seq def permute_sequence(seq): n = len(seq) if n <= 1: return seq mid = int(n/2) left = permute_sequence(seq[:mid]) right = permute_sequence(seq[mid+1:]) ret = [seq[mid]] while left or right: if left: ret.append(left.pop(0)) if right: ret.append(right.pop(0)) return ret def redraw(db,best_param,tofile=False): if len(db) == 0: return begin_level = round(max(x[2] for x in db)) - 3 step_size = 0.5 best_log2c,best_log2g,best_rate = best_param db.sort(key = lambda x:(x[0], -x[1])) prevc = db[0][0] for line in db: if prevc != line[0]: prevc = line[0] def calculate_jobs(): c_seq = permute_sequence(range_f(c_begin,c_end,c_step)) g_seq = permute_sequence(range_f(g_begin,g_end,g_step)) nr_c = float(len(c_seq)) nr_g = float(len(g_seq)) i = 0 j = 0 jobs = [] while i < nr_c or j < nr_g: if i/nr_c < j/nr_g: # increase C resolution line = [(c_seq[i],g_seq[k]) for k in range(j)] i += 1 jobs.append(line) else: # increase g resolution line = [(c_seq[k],g_seq[j]) for k in range(i)] j += 1 jobs.append(line) return jobs class WorkerStopToken: # used to notify the worker to stop pass class Worker(Thread): def __init__(self,name,job_queue,result_queue): Thread.__init__(self) self.name = name self.job_queue = job_queue self.result_queue = result_queue def run(self): while True: cexp,gexp = self.job_queue.get() if cexp is WorkerStopToken: self.job_queue.put((cexp,gexp)) # print 'worker %s stop.' % self.name break try: rate = self.run_one(2.0**cexp,2.0**gexp) if rate is None: raise "get no rate" except: # we failed, let others do that and we just quit traceback.print_exception(*sys.exc_info()) self.job_queue.put((cexp,gexp)) print('worker %s quit.' % self.name) break else: self.result_queue.put((self.name,cexp,gexp,rate)) class LocalWorker(Worker): def run_one(self,c,g): cmdline = '%s -c %s -g %s -v %s %s %s' % \ (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) result = Popen(cmdline,shell=True,stdout=PIPE).stdout for line in result: if "Cross" in str(line): return float(line.split()[-1][0:-1]) class SSHWorker(Worker): def __init__(self,name,job_queue,result_queue,host): Worker.__init__(self,name,job_queue,result_queue) self.host = host self.cwd = os.getcwd() def run_one(self,c,g): cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \ (self.host,self.cwd, svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) result = Popen(cmdline,shell=True,stdout=PIPE).stdout for line in result: if "Cross" in str(line): return float(line.split()[-1][0:-1]) class TelnetWorker(Worker): def __init__(self,name,job_queue,result_queue,host,username,password): Worker.__init__(self,name,job_queue,result_queue) self.host = host self.username = username self.password = password def run(self): import telnetlib self.tn = tn = telnetlib.Telnet(self.host) tn.read_until("login: ") tn.write(self.username + "\n") tn.read_until("Password: ") tn.write(self.password + "\n") # XXX: how to know whether login is successful? tn.read_until(self.username) # print('login ok', self.host) tn.write("cd "+os.getcwd()+"\n") Worker.run(self) tn.write("exit\n") def run_one(self,c,g): cmdline = '%s -c %s -g %s -v %s %s %s' % \ (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname) result = self.tn.write(cmdline+'\n') idx,matchm,output = self.tn.expect(['Cross.*\n']) for line in output.splitlines(): if "Cross" in str(line): return float(line.split()[-1][0:-1]) def main(): # set parameters process_options() # put jobs in queue jobs = calculate_jobs() job_queue = Queue.Queue(0) result_queue = Queue.Queue(0) for line in jobs: for (c,g) in line: job_queue.put((c,g)) job_queue._put = job_queue.queue.appendleft # fire telnet workers if telnet_workers: nr_telnet_worker = len(telnet_workers) username = getpass.getuser() password = getpass.getpass() for host in telnet_workers: TelnetWorker(host,job_queue,result_queue, host,username,password).start() # fire ssh workers if ssh_workers: for host in ssh_workers: SSHWorker(host,job_queue,result_queue,host).start() # fire local workers for i in range(nr_local_worker): LocalWorker('local',job_queue,result_queue).start() # gather results done_jobs = {} result_file = open(out_filename, 'w') db = [] best_rate = -1 best_c1,best_g1 = None,None for line in jobs: for (c,g) in line: while (c, g) not in done_jobs: (worker,c1,g1,rate) = result_queue.get() done_jobs[(c1,g1)] = rate result_file.write('%s %s %s\n' %(c1,g1,rate)) result_file.flush() if rate > best_rate or (rate==best_rate and g1==best_g1 and c1<best_c1): best_rate = rate best_c1,best_g1=c1,g1 best_c = 2.0**c1 best_g = 2.0**g1 print("[%s] %s %s %s (best c=%s, g=%s, rate=%s)" % \ (worker,c1,g1,rate, best_c, best_g, best_rate)) db.append((c,g,done_jobs[(c,g)])) redraw(db,[best_c1, best_g1, best_rate]) redraw(db,[best_c1, best_g1, best_rate],True) job_queue.put((WorkerStopToken,None)) print("%s %s %s" % (best_c, best_g, best_rate)) main()