from avr import *
import struct, sys, time 

# Command codes passed as the first argument of Avr.send_command /
# Avr.get_response (see Programmer).  NOTE(review): presumably these are the
# request numbers understood by the bridge firmware -- confirm in avr.py.
AVR_INFRA_BUFFER_EMPTY = 1
AVR_INFRA_GET_CODE = 2
AVR_DDR_SET = 3
AVR_DDR_GET = 4
AVR_PORT_SET = 5
AVR_PORT_GET = 6
AVR_PIN_GET = 7
AVR_EEPROM_READ = 8
AVR_EEPROM_WRITE = 9
AVR_RS232_WRITE = 10
AVR_RS232_READ = 11
AVR_RS232_BAUD_SET = 12
AVR_RS232_BAUD_GET = 13
AVR_SPI = 14
AVR_SPI4 = 15

# Nonzero: trace every SPI instruction in Programmer.send_spi and list the
# known devices at import time.
debug = 1

class error(Exception):
    """Raised when the target does not respond or programmed data fails verification."""

def frombinary(s):
    """Interpret *s* as a big-endian bit string and return its integer value.

    Every character other than '1' (including '0', spaces and field letters
    such as 'a', 'b', 'i', 'o') contributes a 0 bit.  An empty string gives 0.
    """
    value = 0
    for bit in [ch == '1' for ch in s]:
        value = (value << 1) | bit
    return value

class Template:
    """A 32-bit serial-programming instruction with variable fields.

    ``v`` is the fixed instruction word.  ``offsets`` is a sequence of
    ``(shift, mask)`` pairs, one per positional argument of :meth:`fill`.
    ``response_shift`` / ``response_mask`` select the bits of the 32-bit
    reply that :meth:`extract` returns.
    """

    def __init__(self, v, offsets=(), response_shift=0, response_mask=0xff):
        self.v = v
        self.offsets = offsets
        self.response_shift = response_shift
        self.response_mask = response_mask

    def fill(self, *values):
        """Return ``v`` with each value masked and OR-ed in at its shift.

        Values and offsets are paired positionally; surplus values or
        surplus offsets are ignored.
        """
        word = self.v
        for field, value in zip(self.offsets, values):
            shift, mask = field
            word = word | ((value & mask) << shift)
        return word

    def extract(self, resp):
        """Return the response field: ``(resp >> response_shift) & response_mask``."""
        shifted = resp >> self.response_shift
        return shifted & self.response_mask

    def fromstr(cls, s):
        """Build a Template from a 32-character bit pattern (whitespace ignored).

        '1' bits form the fixed word.  The letters 'a'/'b' together form the
        first fill() field, and 'i' forms the next one.  'o' marks the
        response field.  Each field's shift is taken from the position of its
        rightmost letter, and its width is the number of letters.  The letters
        are presumably the AVR datasheet's notation.
        """
        bits = ''.join(s.split())
        fixed = frombinary(bits)

        def field(letters):
            # (width, shift) of the field spelled with any of *letters*.
            width = sum([bits.count(ch) for ch in letters])
            last = max([bits.rfind(ch) for ch in letters])
            return width, 31 - last

        addr_width, addr_shift = field('ab')
        in_width, in_shift = field('i')
        out_width, out_shift = field('o')

        offsets = []
        if addr_width:
            offsets.append((addr_shift, (1 << addr_width) - 1))
        if in_width:
            offsets.append((in_shift, (1 << in_width) - 1))
        if not out_width:
            return cls(fixed, offsets)
        return cls(fixed, offsets, out_shift, (1 << out_width) - 1)
    fromstr = classmethod(fromstr)

class Programmer:
    """Generic AVR serial (SPI) programmer driven through an ``Avr`` bridge.

    Every low-level operation is forwarded to ``self.dev``.  This base class
    only defines the device-independent instructions (Programming Enable and
    Read Signature).  The templates used by :meth:`erase` and
    :meth:`verify_program` (``chip_erase``, ``read_program_lo``/``_hi``) come
    from the subclass that :meth:`specialize` switches the instance to.
    """

    # Programming Enable; find_avr() expects 0x53 back in bits 8-15 of the
    # 32-bit response.
    programming_enable = Template(0xAC530000, (), 8)
    # Read Signature Byte; the byte index (0..2) is placed at bits 8-9.
    read_signature = Template(0x30000000, ((8, 0x3),))

    def __init__(self, dev=None):
        """Attach to *dev* (a fresh ``Avr()`` when omitted).

        Sends AVR_DDR_SET 0xff to the bridge.  This presumably makes its port
        pins outputs; release() later sends AVR_DDR_SET 0.
        """
        if dev is None:
            dev = Avr()
        self.dev = dev
        self.send_command(AVR_DDR_SET, 0xff)

    def send_command(self, *args):
        """Forward to ``self.dev.send_command`` and return its result."""
        return self.dev.send_command(*args)

    def get_response(self, *args):
        """Forward to ``self.dev.get_response`` and return its result."""
        return self.dev.get_response(*args)

    def reset(self, state=1):
        """Set bit 0x10 of the bridge port, then restore the old value if *state*.

        NOTE(review): presumably bit 0x10 drives the target's RESET pin, so
        reset() produces a pulse and reset(0) leaves the bit set.  Confirm
        against the hardware.
        """
        p = ord(self.get_response(AVR_PORT_GET, 0))
        self.send_command(AVR_PORT_SET, p | 0x10)
        if state:
            self.send_command(AVR_PORT_SET, p)

    def led(self, state=1):
        """Set (*state* true) or clear bit 0x4 of the bridge port.

        NOTE(review): presumably an indicator LED, going by the name.
        """
        p = ord(self.get_response(AVR_PORT_GET, 0))
        if state:
            q = p | 0x4
        else:
            q = p & ~0x4
        self.send_command(AVR_PORT_SET, q)

    def send_spi(self, template, *args):
        """Send one 32-bit instruction built from *template* and *args*.

        Returns the response field selected by ``template.extract``.

        The instruction and the response are handled as *unsigned* big-endian
        words.  Most AVR instructions (0xAC..., 0xA0..., 0xC0...) have the top
        bit set, which lies outside the signed ``"i"`` struct format.  With
        ``"i"``, responses with the top bit set would also print as negative
        numbers.
        """
        word = template.fill(*args)
        b0, b1, b2, b3 = struct.unpack(">4B", struct.pack(">I", word))
        # The four instruction bytes travel as two 16-bit arguments, with the
        # earlier byte of each pair in the low half.
        resp = self.get_response(AVR_SPI4, b0 + b1 * 256, b2 + b3 * 256)
        resp = struct.unpack(">I", resp)[0]
        ret = template.extract(resp)
        if debug:
            print("%08x... %08x [%02x]" % (word, resp, ret))
        return ret

    def find_avr(self, read_signature=1):
        """Enter programming mode and optionally read the device signature.

        Tries up to 32 reset + Programming Enable cycles until the target
        answers 0x53.  Otherwise raises ``error``.  When *read_signature* is
        true, turns the LED on, reads the three signature bytes and stores
        them in ``self.detected_signature`` as a 24-bit integer.
        """
        for attempt in range(32):
            self.reset()
            if self.send_spi(self.programming_enable) == 0x53:
                break
        else:
            raise error("Device not responding")
        if not read_signature:
            return
        self.led(1)
        sig = [self.send_spi(self.read_signature, i) for i in range(3)]
        self.detected_signature = (sig[0] << 16) + (sig[1] << 8) + sig[2]
        print("Device detected.  Signature = %06x" % self.detected_signature)

    def release(self):
        """Send AVR_DDR_SET 0 and AVR_PORT_SET 0 (undoes __init__'s DDR setting)."""
        self.send_command(AVR_DDR_SET, 0)
        self.send_command(AVR_PORT_SET, 0)

    def specialize(self, sig=None):
        """Switch this instance to the device class registered for *sig*.

        *sig* defaults to ``self.detected_signature`` (set by find_avr).  An
        unknown signature is reported and the class is left unchanged.  Only
        the registry lookup is guarded, so genuine errors from the class
        switch are no longer mislabelled "Unknown device".
        """
        if sig is None:
            sig = self.detected_signature
        try:
            device_class = all_programmers[sig]
        except KeyError:
            print("Unknown device.  Signature = %06x" % sig)
        else:
            self.__class__ = device_class
            print("Device identification: %s" % (self.__class__.__name__))

    def erase(self):
        """Send the device's ``chip_erase`` instruction."""
        self.send_spi(self.chip_erase)

    def verify_program(self, bytes):
        """Read back program memory and compare it with *bytes*.

        Even indices are compared with ``read_program_lo`` and odd indices
        with ``read_program_hi``, both at word address ``index // 2``.
        Raises ``error`` on the first mismatch.
        """
        print("Verifying program")
        for i, expected in enumerate(bytes):
            if i % 2:
                template = self.read_program_hi
            else:
                template = self.read_program_lo
            r = self.send_spi(template, i // 2)
            if r != expected:
                raise error("Mismatch at byte %d: expected %02x, read %02x" % (i, expected, r))

class unbanked_memory:
    """Counterpart of banked_memory; presumably marks non-page-buffered flash.

    It defines nothing, so classes mixing it in get no write_program from it.
    """
class unbanked_eeprom:
    """Marker mixin for the EEPROM layout; currently defines no behaviour."""

class banked_memory:
    """Mixin for devices whose flash is written one page at a time.

    The concrete class must provide ``program_pagesize`` (bytes per page),
    the ``load_page_lo``/``load_page_hi``/``write_page`` templates,
    ``delay_write_page`` (seconds) and ``send_spi`` (from Programmer).
    """
    def write_program(self, bytes):
        """Upload *bytes* (a byte-addressed program image) page by page.

        Even-indexed bytes go to ``load_page_lo`` and odd-indexed bytes to
        ``load_page_hi``, at word offset ``index_in_page // 2``.  Each page is
        then committed with ``write_page`` (page number) and the method sleeps
        ``delay_write_page`` seconds.  Floor division keeps the word and page
        numbers integral even under true division.
        """
        print("Uploading program")
        pagesize = self.program_pagesize
        for start in range(0, len(bytes), pagesize):
            page = bytes[start:start + pagesize]
            for offset, value in enumerate(page):
                word = offset // 2
                if offset % 2:
                    self.send_spi(self.load_page_hi, word, value)
                else:
                    self.send_spi(self.load_page_lo, word, value)
            self.send_spi(self.write_page, start // pagesize)
            time.sleep(self.delay_write_page)
class banked_eeprom:
    """Counterpart of unbanked_eeprom; defines nothing and no class here uses it."""

class At90s2313(Programmer, unbanked_memory, unbanked_eeprom):
    """Serial-programming instruction templates for the AT90S2313.

    Each Template's offsets give the (shift, mask) of the arguments passed to
    Programmer.send_spi.

    NOTE(review): neither this class nor its visible bases define
    write_program, which main() calls after specialize().  The
    write_program_lo/_hi templates are not used anywhere in this file.
    """
    signature = 0x1e9101

    # Read program memory low/high byte; word address (10 bits) at bit 8.
    read_program_lo = Template(0x20000000, [(8, 0x3ff)])
    read_program_hi = Template(0x28000000, [(8, 0x3ff)])

    # Write program memory low/high byte: word address at bit 8, data at bit 0.
    write_program_lo = Template(0x40000000, [(8, 0x3ff), (0, 0xff)])
    write_program_hi = Template(0x48000000, [(8, 0x3ff), (0, 0xff)])

    # EEPROM access with a 7-bit address at bit 8.  write_eeprom also takes
    # the data byte at bit 0, like write_program_*; without that field,
    # fill(addr, data) would silently drop the data.
    read_eeprom = Template(0xa0000000, [(8, 0x7f)])
    write_eeprom = Template(0xc0000000, [(8, 0x7f), (0, 0xff)])

    # Two lock bits at bits 17-18.
    write_lock = Template(0xace00000, [(17, 0x3)])

    chip_erase = Template(0xac800000)

class Atmega16(Programmer, banked_memory, unbanked_eeprom):
    """Serial-programming instruction templates for the ATmega16.

    Each Template's offsets give the (shift, mask) of the arguments passed to
    Programmer.send_spi.  Flash is written page-wise via banked_memory.
    """
    signature = 0x1e9403

    # 128 bytes per page = 64 words, matching the 6-bit (0x3f) word offset
    # of load_page_lo/_hi.
    program_pagesize = 1<<7
    # Read program memory low/high byte; word address (13 bits) at bit 8.
    read_program_lo = Template(0x20000000, [(8, 0x1fff)])
    read_program_hi = Template(0x28000000, [(8, 0x1fff)])

    # Load page buffer low/high byte: word offset in page at bit 8, data at
    # bit 0.
    load_page_lo = Template(0x40000000, [(8, 0x3f), (0, 0xff)])
    load_page_hi = Template(0x48000000, [(8, 0x3f), (0, 0xff)])
    # Commit a page: page number at bit 14, i.e. word address page*64 at bit 8.
    write_page   = Template(0x4c000000, [(14, 0x7f)])
    # Seconds to wait after each write_page.
    delay_write_page = 4.5/1000

    # EEPROM access with a 9-bit address at bit 8.  write_eeprom also takes
    # the data byte at bit 0, like load_page_*; without that field,
    # fill(addr, data) would silently drop the data.
    read_eeprom = Template(0xa0000000, [(8, 0x1ff)])
    write_eeprom = Template(0xc0000000, [(8, 0x1ff), (0, 0xff)])

    # Lock bits: six bits, returned in / written to the low 6 bits.
    read_lock = Template(0x58000000, [], 0, 0x3f)
    write_lock = Template(0xace000c0, [(0, 0x3f)])
    chip_erase = Template(0xac800000)

def read_ihex(f):
    """Parse Intel HEX records from *f* into a flat list of byte values.

    *f* is any iterable of text lines (e.g. an open file).  Lines not
    starting with ':' are skipped.  Parsing stops at the end-of-file record
    (type 01).  Only data records (type 00) are supported.  Any other
    record type raises ``error`` rather than having its payload silently
    written into memory.  Each record's byte count and checksum are
    verified.

    Returns a list indexed by address up to the highest address seen.  Gaps
    are filled with 0.  Returns an empty list if there are no data records.
    """
    memory = {}
    for line in f:
        line = line.strip()
        if not line.startswith(":"):
            continue
        hexdigits = line[1:]
        # count(2) + address(4) + type(2) + checksum(2) hex digits at minimum.
        if len(hexdigits) % 2 or len(hexdigits) < 10:
            raise error("Malformed Intel HEX record: %r" % line)
        try:
            record = [int(hexdigits[k:k + 2], 16)
                      for k in range(0, len(hexdigits), 2)]
        except ValueError:
            raise error("Malformed Intel HEX record: %r" % line)
        bytecount = record[0]
        address = (record[1] << 8) | record[2]
        rectype = record[3]
        data = record[4:-1]
        if len(data) != bytecount:
            raise error("Byte count mismatch in Intel HEX record: %r" % line)
        # All bytes, including the trailing checksum, must sum to 0 mod 256.
        if sum(record) & 0xff:
            raise error("Checksum mismatch in Intel HEX record: %r" % line)
        if rectype == 1:
            break
        if rectype != 0:
            raise error("Unsupported record type 0x%02x" % rectype)
        for offset, value in enumerate(data):
            memory[address + offset] = value
    if not memory:
        return []
    image = [0] * (max(memory.keys()) + 1)
    for addr, value in memory.items():
        image[addr] = value
    return image

def make_all_programmers():
    """(Re)build the global ``all_programmers`` registry.

    Scans this module's globals for classes of the same kind as Programmer
    that subclass Programmer.  Programmer itself is excluded.  Each match's
    ``signature`` attribute is mapped to the class.  Programmer.specialize
    looks devices up in this dict.
    """
    global all_programmers
    all_programmers = registry = {}
    class_kind = type(Programmer)
    for candidate in list(globals().values()):
        if candidate is Programmer:
            continue
        if not isinstance(candidate, class_kind):
            continue
        if issubclass(candidate, Programmer):
            registry[candidate.signature] = candidate

# Build the signature -> device-class registry once every class above exists.
make_all_programmers()

if debug:
    print("Known devices:")
    for sig, device_class in all_programmers.items():
        print("  %06x: %s" % (sig, device_class.__name__))

def main(argv):
    """Program the Intel HEX file named by ``argv[1]`` into the attached AVR.

    The hex file is read and parsed (and closed) before the target is
    touched, so a missing or malformed file no longer leaves the chip
    erased.  The programmer's pins are released even if a step fails.
    Exits with a usage message when no file name is given.
    """
    if len(argv) < 2:
        if argv:
            prog = argv[0]
        else:
            prog = "program"
        sys.exit("usage: %s FILE.hex" % prog)

    f = open(argv[1])
    try:
        image = read_ihex(f)
    finally:
        f.close()

    p = Programmer()
    try:
        p.find_avr()
        p.specialize()
        p.erase()
        # Reset and re-run Programming Enable after the erase, without
        # reading the signature again.
        p.find_avr(0)
        p.write_program(image)
        p.verify_program(image)
    finally:
        p.release()

# Script entry point: program the hex file named on the command line.
if __name__ == "__main__":
    main(sys.argv)


