module APL.Eval ( eval, ) where import APL.AST (Exp (..)) import APL.Monad evalIntBinOp :: (Integer -> Integer -> EvalM Integer) -> Exp -> Exp -> EvalM Val evalIntBinOp f e1 e2 = do v1 <- eval e1 v2 <- eval e2 case (v1, v2) of (ValInt x, ValInt y) -> ValInt <$> f x y (_, _) -> failure "Non-integer operand" evalIntBinOp' :: (Integer -> Integer -> Integer) -> Exp -> Exp -> EvalM Val evalIntBinOp' f e1 e2 = evalIntBinOp f' e1 e2 where f' x y = pure $ f x y -- Replaced their eval with ours as instructed NOTE eval :: Exp -> EvalM Val eval (CstInt x) = pure $ ValInt x eval (CstBool b) = pure $ ValBool b eval (Var v) = do env <- askEnv case envLookup v env of Just x -> pure x Nothing -> failure $ "Unknown variable: " ++ v eval (Add e1 e2) = evalIntBinOp' (+) e1 e2 eval (Sub e1 e2) = evalIntBinOp' (-) e1 e2 eval (Mul e1 e2) = evalIntBinOp' (*) e1 e2 eval (Div e1 e2) = evalIntBinOp checkedDiv e1 e2 where checkedDiv _ 0 = failure "Division by zero" checkedDiv x y = pure $ x `div` y eval (Pow e1 e2) = evalIntBinOp checkedPow e1 e2 where checkedPow x y = if y < 0 then failure "Negative exponent" else pure $ x ^ y eval (Eql e1 e2) = do v1 <- eval e1 v2 <- eval e2 case (v1, v2) of (ValInt x, ValInt y) -> pure $ ValBool $ x == y (ValBool x, ValBool y) -> pure $ ValBool $ x == y (_, _) -> failure "Invalid operands to equality" eval (If cond e1 e2) = do cond' <- eval cond case cond' of ValBool True -> eval e1 ValBool False -> eval e2 _ -> failure "Non-boolean conditional." eval (Let var e1 e2) = do v1 <- eval e1 localEnv (envExtend var v1) $ eval e2 eval (Lambda var body) = do env <- askEnv pure $ ValFun env var body eval (Apply e1 e2) = do v1 <- eval e1 v2 <- eval e2 case (v1, v2) of (ValFun f_env var body, arg) -> localEnv (const $ envExtend var arg f_env) $ eval body (_, _) -> failure "Cannot apply non-function" eval (TryCatch e1 e2) = eval e1 `catch` eval e2 eval (Print s e1) = do v1 <- eval e1 case v1 of (ValInt i) -> do evalPrint (s++": "++(show i)) pure $ v1 (ValBool b) -> do evalPrint (s++": "++(show b)) pure $ v1 (ValFun _ _ _) -> do evalPrint (s++": #") pure $ v1 eval (KvPut e1 e2) = do v1 <- eval e1 v2 <- eval e2 evalKvPut v1 v2 pure $ v2 eval (KvGet e) = do v <- eval e evalKvGet v