// cpu.c
//
// Created by imanol on 12/25/16.
//

#include <stdlib.h>
#include "cpu.h"

// mov: write a value into the register selected by dst.
// A src above MAX_INT is a register reference whose contents are copied;
// any other src is stored as a literal. dst is always a register index.
void mov(uint16_t dst, uint16_t src)
{
    regs[dst % (MAX_INT+1)] = (src > MAX_INT) ? regs[src % (MAX_INT+1)] : src;
}

// push: place a value on the VM stack.
// A src above MAX_INT names a register whose current contents are pushed;
// otherwise src itself is pushed as a literal.
void push(uint16_t src)
{
    const int names_register = (src > MAX_INT);
    stack_push(names_register ? regs[src % (MAX_INT+1)] : src);
}

// pop: remove the top stack value and store it in register dst.
// If the stack reports a fault after the pop, the condition is reported on
// stderr and core_dump() is invoked (which terminates the process).
void pop(uint16_t dst)
{
    const uint16_t popped = stack_pop();

    if(STACK_FAULT)
    {
        fputs("CRITICAL: STACK VIOLATION\n", stderr);
        core_dump();
    }

    regs[dst % (MAX_INT+1)] = popped;
}

// teq: set register dst to 1 when a equals b, otherwise to 0.
// Either comparand may be a register reference (a value above MAX_INT).
void teq(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;

    regs[dst % (MAX_INT+1)] = (lhs == rhs) ? 1 : 0;
}

// tgt: set register dst to 1 when a is strictly greater than b, else 0.
// Either comparand may be a register reference (a value above MAX_INT).
void tgt(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;

    regs[dst % (MAX_INT+1)] = (rhs < lhs) ? 1 : 0;
}

// jmp: transfer control by loading pc.
// A literal dst (<= MAX_INT) is the target address itself; a larger dst
// names a register holding the target address.
void jmp(uint16_t dst)
{
    if(dst <= MAX_INT)
    {
        pc = dst;
        return;
    }

    pc = regs[dst % (MAX_INT+1)];
}

// jnz: jump to dst (via jmp) when cond is non-zero.
// cond may be a register reference (a value above MAX_INT).
void jnz(uint16_t cond, uint16_t dst)
{
    const uint16_t flag = (cond > MAX_INT) ? regs[cond % (MAX_INT+1)] : cond;

    if(flag == 0)
    {
        return;
    }
    jmp(dst);
}

// jz: jump to dst (via jmp) when cond is zero.
// cond may be a register reference (a value above MAX_INT).
void jz(uint16_t cond, uint16_t dst)
{
    const uint16_t flag = (cond > MAX_INT) ? regs[cond % (MAX_INT+1)] : cond;

    if(flag != 0)
    {
        return;
    }
    jmp(dst);
}

// add: store (a + b) reduced modulo MAX_INT+1 into register dst.
// Operands above MAX_INT are register references and are resolved first.
void add(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;
    const uint32_t sum = (uint32_t)lhs + rhs;

    regs[dst % (MAX_INT+1)] = (uint16_t)(sum % (MAX_INT+1));
}

// mul: store (a * b) reduced modulo MAX_INT+1 into register dst.
// Operands above MAX_INT are register references and are resolved first.
void mul(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;

    // Multiply in uint32_t. Plain uint16_t operands are promoted to signed
    // int, and if a register holds a value above MAX_INT (in() can store
    // sign-extended bytes) the product can exceed INT_MAX, which is
    // undefined behaviour. Any two 16-bit values fit in 32 unsigned bits.
    const uint32_t product = (uint32_t)lhs * (uint32_t)rhs;

    regs[dst % (MAX_INT+1)] = (uint16_t)(product % (MAX_INT+1));
}

// mod: store the remainder a % b into register dst.
// Operands above MAX_INT are register references and are resolved first.
// A zero divisor is a fatal VM error: it is reported on stderr and
// core_dump() is invoked, which terminates the process.
void mod(uint16_t dst, uint16_t a, uint16_t b)
{
    if(a > MAX_INT)
    {
        a = regs[a % (MAX_INT+1)];
    }
    if(b > MAX_INT)
    {
        b = regs[b % (MAX_INT+1)];
    }

    // x % 0 is undefined behaviour in C (commonly a SIGFPE crash with no
    // diagnostics), so trap it explicitly and dump the VM state instead.
    if(b == 0)
    {
        fprintf(stderr,"CRITICAL: MODULO BY ZERO\n");
        core_dump();
        return;
    }

    regs[dst % (MAX_INT+1)] = (a % b);
}

// and: store the bitwise AND of a and b into register dst.
// Operands above MAX_INT are register references and are resolved first.
void and(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;

    regs[dst % (MAX_INT+1)] = lhs & rhs;
}

// or: store the bitwise OR of a and b into register dst.
// Operands above MAX_INT are register references and are resolved first.
void or(uint16_t dst, uint16_t a, uint16_t b)
{
    const uint16_t lhs = (a > MAX_INT) ? regs[a % (MAX_INT+1)] : a;
    const uint16_t rhs = (b > MAX_INT) ? regs[b % (MAX_INT+1)] : b;

    regs[dst % (MAX_INT+1)] = lhs | rhs;
}

// not: store the bitwise complement of src, masked to the low 15 bits,
// into register dst. src may be a register reference (above MAX_INT).
void not(uint16_t dst, uint16_t src)
{
    const uint16_t operand = (src > MAX_INT) ? regs[src % (MAX_INT+1)] : src;
    const uint16_t inverted = (uint16_t)(0x7FFF & ~operand);

    regs[dst % (MAX_INT+1)] = inverted;
}

// load: read the memory word at address src into register dst.
// src may be a register reference (above MAX_INT) holding the address.
void load(uint16_t dst, uint16_t src)
{
    const uint16_t address = (src > MAX_INT) ? regs[src % (MAX_INT+1)] : src;

    regs[dst % (MAX_INT+1)] = mem[address];
}

// stor: write the value src into memory at address dst.
// Both operands may be register references (above MAX_INT): src then
// supplies the value from a register and dst the address from a register.
void stor(uint16_t dst, uint16_t src)
{
    const uint16_t value = (src > MAX_INT) ? regs[src % (MAX_INT+1)] : src;
    const uint16_t address = (dst > MAX_INT) ? regs[dst % (MAX_INT+1)] : dst;

    mem[address] = value;
}

// call: push the return address and jump to dst (via jmp).
// By the time an instruction executes, fetch() has already advanced pc past
// its operand words, so pc is the address of the next instruction.
void call(uint16_t dst)
{
    const uint16_t return_address = pc;

    stack_push(return_address);
    jmp(dst);
}

// ret: pop a return address off the stack into pc.
// Returns STACK_FAULT so the caller can stop execution when the pop failed.
uint8_t ret()
{
    const uint16_t target = stack_pop();

    pc = target;
    return STACK_FAULT;
}

// out: write src to stdout as a single character via putchar().
// src may be a register reference (above MAX_INT).
void out(uint16_t src)
{
    const uint16_t ch = (src > MAX_INT) ? regs[src % (MAX_INT+1)] : src;

    putchar(ch);
}

// in: read one character of input into register dst.
// The input buffer is refilled via read_input() whenever it is empty.
void in(uint16_t dst)
{
    if(!input_buffer_size())
    {
        read_input();
    }

    // Narrow through unsigned char, not plain char. Where char is signed,
    // bytes >= 0x80 would sign-extend to 0xFF80..0xFFFF when widened to
    // uint16_t, i.e. values far outside the 15-bit word range used
    // throughout this file.
    // NOTE(review): assumes get_input() yields a byte value; an EOF
    // indicator (if it can return one) is not distinguished here.
    const unsigned char c = (unsigned char)get_input();
    regs[dst % (MAX_INT + 1)] = (uint16_t)c;
}

// nop: the no-operation instruction; intentionally does nothing.
void nop()
{
}

// fetch: return the memory word at pc and advance pc by one.
uint16_t fetch()
{
    return mem[pc++];
}


// decode_instruction: fetch the operand words that follow `opcode`.
// Exactly as many operands as the instruction takes are fetched, in order
// into *arg1, *arg2, *arg3; out-parameters beyond that count are left
// untouched. Unknown opcodes fetch nothing.
void decode_instruction(uint16_t opcode, uint16_t *arg1, uint16_t *arg2, uint16_t *arg3)
{
    uint16_t *const slots[3] = { arg1, arg2, arg3 };
    int operand_count;

    switch(opcode)
    {
        case PUSH: case POP: case JMP: case CALL: case OUT: case IN:
            operand_count = 1;
            break;
        case MOV: case JNZ: case JZ: case NOT: case LOAD: case STOR:
            operand_count = 2;
            break;
        case TEQ: case TGT: case ADD: case MUL: case MOD: case AND: case OR:
            operand_count = 3;
            break;
        case HALT: case RET: case NOP:
        default:
            operand_count = 0;
            break;
    }

    for(int i = 0; i < operand_count; i++)
    {
        *slots[i] = fetch();
    }
}

// execute_instruction: dispatch one decoded instruction to its handler.
// Returns 1 for HALT, ret()'s STACK_FAULT status for RET, and 0 otherwise;
// a non-zero result tells the caller to stop executing. An unknown opcode is
// reported on stderr and core_dump() is invoked, which terminates the process.
uint8_t execute_instruction(uint16_t opcode, uint16_t arg1, uint16_t arg2, uint16_t arg3)
{
    switch(opcode)
    {
        case HALT:
            return 1;
        case MOV:
            mov(arg1,arg2);
            break;
        case PUSH:
            push(arg1);
            break;
        case POP:
            pop(arg1);
            break;
        case TEQ:
            teq(arg1,arg2,arg3);
            break;
        case TGT:
            tgt(arg1,arg2,arg3);
            break;
        case JMP:
            jmp(arg1);
            break;
        case JNZ:
            jnz(arg1,arg2);
            break;
        case JZ:
            jz(arg1,arg2);
            break;
        case ADD:
            add(arg1,arg2,arg3);
            break;
        case MUL:
            mul(arg1,arg2,arg3);
            break;
        case MOD:
            mod(arg1,arg2,arg3);
            break;
        case AND:
            and(arg1,arg2,arg3);
            break;
        case OR:
            or(arg1,arg2,arg3);
            break;
        case NOT:
            not(arg1,arg2);
            break;
        case LOAD:
            load(arg1,arg2);
            break;
        case STOR:
            stor(arg1,arg2);
            break;
        case CALL:
            call(arg1);
            break;
        case RET:
            return ret();
        case OUT:
            out(arg1);
            break;
        case IN:
            in(arg1);
            break;
        case NOP:
            nop();
            break;
        default:
            // decode_instruction() fetches no operands for an unknown opcode,
            // so pc is one word past the opcode; report the opcode's own
            // address rather than the following one.
            // NOTE(review): assumes this is called right after fetch+decode,
            // as start_execution() does.
            // Arguments are cast to unsigned to match the %x conversions.
            fprintf(stderr,"CRITICAL: UNKNOWN OPCODE %2x FOUND IN %2x\n",
                    (unsigned)opcode, (unsigned)(uint16_t)(pc - 1));
            core_dump();
            break;
    }
    return 0;
}

// start_execution: run the fetch/decode/execute loop from the current pc
// until execute_instruction() returns non-zero.
void start_execution()
{
    // Zero-initialised: decode_instruction() only writes the operands an
    // opcode actually takes, so a zero-operand instruction (HALT, RET, NOP)
    // on the first iteration would otherwise pass indeterminate values.
    uint16_t arg1 = 0;
    uint16_t arg2 = 0;
    uint16_t arg3 = 0;

    for(;;)
    {
        const uint16_t opcode = fetch();
        decode_instruction(opcode, &arg1, &arg2, &arg3);
        if(execute_instruction(opcode, arg1, arg2, arg3))
        {
            return;
        }
    }
}

// Write `count` 16-bit words from `data` to the binary file at `path`.
// Failures are reported on stderr but are not fatal: this runs on the crash
// path, and the remaining dump plus the process exit should still happen.
static void dump_words_to_file(const char *path, const uint16_t *data, size_t count)
{
    FILE *fp = fopen(path, "wb");
    if(fp == NULL)
    {
        fprintf(stderr, "WARNING: could not open %s for writing\n", path);
        return;
    }

    // Passing NULL to fwrite is undefined even for zero items, so an empty
    // dump just leaves an empty file behind.
    if(data != NULL && count > 0 && fwrite(data, sizeof(uint16_t), count, fp) != count)
    {
        fprintf(stderr, "WARNING: short write to %s\n", path);
    }

    if(fclose(fp) != 0)
    {
        fprintf(stderr, "WARNING: could not close %s\n", path);
    }
}

// core_dump: print the registers, write the full memory image to
// "synacor.mem" and the stack contents to "synacor.stack" (raw uint16_t
// words), then terminate the process with exit status 1.
void core_dump()
{
    // static: MEMSIZE words can be large, and this runs on failure paths
    // where stack headroom is not guaranteed.
    static uint16_t memdump[MEMSIZE];
    uint16_t *stackdump = NULL;
    uint32_t stacksize;

    print_regs();
    dump_memory(memdump);
    // NOTE(review): ownership of the buffer stack_dump() returns is not
    // visible here; it is not freed (the process exits immediately anyway).
    stacksize = stack_dump(&stackdump);

    dump_words_to_file("synacor.mem", memdump, MEMSIZE);
    dump_words_to_file("synacor.stack", stackdump, stacksize);

    exit(1);
}