;; --------------------------------------------------------------------
;; HybridSAL
;; Copyright (C) 2006, SRI International.  All Rights Reserved.
;; 
;; This program is free software; you can redistribute it and/or
;; modify it under the terms of the GNU General Public License
;; as published by the Free Software Foundation; either version 2
;; of the License, or (at your option) any later version.
;; 
;; This program is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;; GNU General Public License for more details.
;; 
;; You should have received a copy of the GNU General Public License
;; along with this program; if not, write to the Free Software
;; Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
;; --------------------------------------------------------------------

(in-package :sal)

;; (export '(typecheck-file typecheck tc-eq))

;; var to type mapping
;; *ENVIRONMENT*: alist of (variable-id . type-name) pairs.  TYPECHECK* on a
;; VAR-DECL pushes new entries; ENV-TYPE looks names up in it.
(defvar *environment* nil)
;; *CONTEXT*: the context most recently passed to TYPECHECK*.  Module
;; instances are resolved against the declarations in its body.
(defvar *context* nil)

;; ------------------------------------------------------
;; tc-eq
;; ------------------------------------------------------
(defmethod tc-eq ((expr expression) (typename symbol))
  "Typecheck EXPR, then return true iff its declared type is TYPENAME."
  (typecheck* expr)
  (eq typename (declared-type expr)))

(defmethod tc-eq ((vdecl var-decl) (typename symbol))
  "Return true iff the (already resolved) declared type of VDECL is TYPENAME.
Both the declared type and TYPENAME must be non-NIL."
  (let ((vtype (declared-type vdecl)))
    (assert (and vtype typename))
    (eq typename vtype)))

(defmethod tc-eq ((expr t) (typename symbol))
  ;; Fallback for anything that is neither an expression nor a var-decl.
  (sal-error expr "tc-eq called with non-expression~%"))
;; ------------------------------------------------------

;; ------------------------------------------------------
;; typecheck
;; ------------------------------------------------------
(defmethod typecheck ((vdecl var-decl) typename)
  "Typecheck the variable declaration VDECL.  TYPENAME is not used here;
the :after method compares it with the resolved declared type."
  (declare (ignore typename))
  (typecheck* vdecl))

(defmethod typecheck ((expr expression) typename)
  "Typecheck EXPR unless it already carries a declared type, then return T.
TYPENAME is checked by the :after method."
  (declare (ignore typename))
  (unless (declared-type expr)
    (typecheck* expr))
  (assert (declared-type expr))
  t)

(defmethod typecheck :after ((expr t) typename)
  "After the primary method, require that EXPR has a declared type.  When
TYPENAME is non-NIL, also require that this type is exactly TYPENAME;
otherwise signal a SAL error.  A NIL TYPENAME means no type is expected."
  (assert (declared-type expr))
  (when (and typename
             (not (eq (declared-type expr) typename)))
    ;; Report both types so the user can see what mismatched.
    (sal-error nil "Typecheck error: expected ~a, found ~a~%"
               typename (declared-type expr)))
  t)

(defun env-type (id)
  "Return the type name bound to ID in *ENVIRONMENT*, or NIL if ID is unbound."
  (declare (special *environment*))
  (let ((binding (assoc id *environment*)))
    (and binding (cdr binding))))

(defmethod typecheck* ((expr numeral))
  ;; Every numeric literal is given type REAL.
  (setf (declared-type expr) 'REAL))

(defmethod typecheck* ((expr name-expr))
  ;; TRUE and FALSE are BOOLEAN; any other name takes its type from the
  ;; variable environment (NIL when it is not bound there).
  (let ((name (id expr)))
    (setf (declared-type expr)
          (if (member name '(TRUE FALSE))
              'BOOLEAN
              (env-type name)))))

(defmethod typecheck* ((expr unary-application))
  ;; Each supported unary operator takes and yields the same type:
  ;; unary minus works on REAL, NOT works on BOOLEAN.
  (let* ((opsym (op2symbol (operator expr)))
         (rtype (case opsym
                  (- 'REAL)
                  (NOT 'BOOLEAN))))
    (if rtype
        (progn
          (typecheck (args1 expr) rtype)
          (setf (declared-type expr) rtype))
        (sal-error expr "typecheck found unknown operator~%"))))

(defmethod typecheck* ((expr application))
  ;; Dispatch on the operator symbol.  Each branch checks the operand types
  ;; it requires and then records the result type on EXPR.
  (let ((opsym (op2symbol (operator expr))))
    (flet ((both (type)
             ;; First and second operands must both have TYPE.
             (typecheck (args1 expr) type)
             (typecheck (args2 expr) type))
           (yields (type)
             (setf (declared-type expr) type)))
      (case opsym
        ((AND OR IMPLIES IFF => <=>)
         (both 'BOOLEAN)
         (yields 'BOOLEAN))
        ((+ - * /)
         (both 'REAL)
         (yields 'REAL))
        ((< > <= >=)
         (both 'REAL)
         (yields 'BOOLEAN))
        ((= /=)
         ;; The second operand must match whatever type the first one has.
         (typecheck* (args1 expr))
         (typecheck (args2 expr) (declared-type (args1 expr)))
         (yields 'BOOLEAN))
        ((G F NOT)
         (typecheck (args1 expr) 'BOOLEAN)
         (yields 'BOOLEAN))
        (IF
         ;; Condition is BOOLEAN; the remaining branch must match the type
         ;; of the second argument, which is also the result type.
         (typecheck (args1 expr) 'BOOLEAN)
         (typecheck* (args2 expr))
         (typecheck (args expr 2) (declared-type (args2 expr)))
         (yields (declared-type (args2 expr))))
        (t
         (sal-error expr "typecheck found unknown operator~%"))))))

(defmethod typecheck* ((vdecl var-decl))
  ;; Replace the declared type object by its name.  Record the variable in
  ;; *ENVIRONMENT* unless it is already bound there.  Returns T if it was
  ;; already bound, otherwise the extended environment.
  (declare (special *environment*))
  (let ((type-obj (declared-type vdecl)))
    (assert type-obj)
    (let ((type-name (id type-obj))
          (var-name (id vdecl)))
      (setf (declared-type vdecl) type-name)
      (cond ((env-type var-name) t)
            (t (push (cons var-name type-name) *environment*))))))
;; ------------------------------------------------------

;; ------------------------------------------------------
;;; Similar to pc-typecheck
(defun sal-typecheck-term (term expected)
  "Typecheck TERM against the EXPECTED type name.  TERM may be a SAL source
string or an already-built SAL-SYNTAX object."
  (sal-typecheck-term* term expected))

(defmethod sal-typecheck-term* ((term string) expected)
  ;; Parse the source text first, then typecheck the resulting term.
  (let ((parsed (sal-parse-term term)))
    (typecheck parsed expected)))

(defmethod sal-typecheck-term* ((term sal-syntax) expected)
  ;; Already parsed: typecheck directly.
  (typecheck term expected))
;; ------------------------------------------------------
  
;; ------------------------------------------------------
;; recursively descend context structure and call typecheck*
;; ------------------------------------------------------
(defmethod typecheck-file ((file string))
  "Parse FILE with SAL-PARSE and return the result of typechecking it."
  (typecheck* (sal-parse file)))

(defmethod typecheck* ((c context))
  "Make C the current *CONTEXT* and typecheck each of its declarations in
order.  Returns C."
  (declare (special *context*))
  (let ((body (context-body c)))
    (assert (every #'sal-declaration? body))
    (sal-message "Typechecking context ~a" (id c))
    (setf *context* c)
    (dolist (decl body)
      (typecheck* decl)))
  c)

(defmethod typecheck* ((decl constant-declaration))
  ;; Constants must already carry a declared type; nothing else is checked.
  (assert (declared-type decl)))

(defmethod typecheck* ((decl type-declaration))
  ;; Type declarations are outside the supported fragment.
  (sal-error nil "HybridSal-Light cannot handle type declarations"))

(defmethod typecheck* ((decl module-declaration))
  ;; Give the module a module-instance reference named after this
  ;; declaration, then typecheck the module itself.
  (let* ((the-module (module decl))
         (self-ref (mk-sal-moduleinstance
                    nil (mk-sal-modulename nil (id decl)) nil)))
    (setf (modref the-module) self-ref)
    (typecheck* the-module)))

(defmethod typecheck* ((decl assertion-declaration))
  (typecheck* (assertion decl)))

(defmethod typecheck* ((a module-models))
  ;; Resolve the module instance first, then check the asserted formula.
  (typecheck* (module a))
  (typecheck* (assertion a)))

(defmethod typecheck* ((type scalar-type))
  ;; No checks are performed; the result of SAL-MESSAGE is returned.
  (sal-message "typecheck* scalar-type"))

(defmethod typecheck* ((mod base-module))
  "Typecheck the base module MOD at most once (guarded by TYPECHECKED?).
Checks the state variables first, then caches MOD's input/output/global/
local declaration lists, then checks the definition, invariant,
initformula, initialization and transition sections.  Returns MOD."
  (when (not (typechecked? mod))
    (setf (typechecked? mod) t)
    (let* ((*current-sal-module* mod)
           (state-vars (state-var-decls mod)))
      (declare (special *current-sal-module*))
      (mapc #'typecheck* state-vars)
      (flet ((decls-satisfying (pred)
               (remove-if-not pred (declarations mod))))
        (setf (input-decls mod) (decls-satisfying #'input-decl?)
              (output-decls mod) (decls-satisfying #'output-decl?)
              (global-decls mod) (decls-satisfying #'global-decl?)
              (local-decls mod) (decls-satisfying #'local-decl?)))
      (mapc #'typecheck* (def-decls mod))
      (mapc #'typecheck* (invar-decls mod))
      (mapc #'typecheck* (initfor-decls mod))
      (mapc #'typecheck* (init-decls mod))
      (mapc #'typecheck* (trans-decls mod))))
  mod)

(defmethod state-var-decls ((mod base-module))
  ;; All declarations of MOD that satisfy STATE-VAR-DECL?.
  (remove-if (complement #'state-var-decl?) (declarations mod)))

(defmethod state-var-decls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (state-var-decls target)))

(defmethod state-var-decls ((mod composition))
  "Return the state variable declarations of the composition MOD.
Component inputs whose ID names an output or a global of either component
are not inputs of the composition; this matches INPUT-VARDECLS on
compositions.  Globals shared by both components are listed once; this
matches GLOBAL-VARDECLS on compositions.  Variables are identified by ID."
  (let* ((m1 (module1 mod))
	 (m2 (module2 mod))
	 (inputs (union (input-vardecls m1) (input-vardecls m2) :key #'id))
	 (outputs (append (output-vardecls m1) (output-vardecls m2)))
	 ;; Previously GLOB1 and GLOB2 were appended, which duplicated
	 ;; shared globals.
	 (globals (union (global-vardecls m1) (global-vardecls m2) :key #'id))
	 (locals (append (local-vardecls m1) (local-vardecls m2))))
    (append (set-difference inputs (append outputs globals) :key #'id)
	    outputs globals locals)))

(defmethod def-decls ((mod base-module))
  ;; The DEFINITION declarations of MOD.
  (remove-if (complement #'def-decl?) (declarations mod)))

(defmethod init-decls ((mod base-module))
  ;; The initialization declarations of MOD.
  (remove-if (complement #'init-decl?) (declarations mod)))

(defmethod trans-decls ((mod base-module))
  ;; The transition declarations of MOD.
  (remove-if (complement #'trans-decl?) (declarations mod)))

(defmethod typecheck* ((invar invar-decl))
  ;; An invariant declaration is checked through its expression.
  (let ((body (expression invar)))
    (typecheck* body)))

(defmethod typecheck* ((initfor initfor-decl))
  ;; Likewise for an initformula declaration.
  (let ((body (expression initfor)))
    (typecheck* body)))

(defmethod invar-decls ((mod base-module))
  (remove-if (complement #'invar-decl?) (declarations mod)))

(defmethod initfor-decls ((mod base-module))
  (remove-if (complement #'initfor-decl?) (declarations mod)))

(defmethod typecheck* ((decl def-decl))
  "Typecheck each definition held by the definition declaration DECL.
DEFINITIONS of a def-decl is a list (DEFINITIONS on base modules
concatenates these lists with COPY-LIST), so each element is checked in
turn, as the init and trans methods do.  A single non-list definition
object is still accepted."
  (let ((defs (definitions decl)))
    (if (listp defs)
	(dolist (d defs)
	  (typecheck* d))
	(typecheck* defs))))

(defmethod typecheck* ((def simple-definition))
  ;; Check the defined left-hand side, then its right-hand side.
  (let ((lhs-expr (lhs def))
        (rhs-expr (rhs-definition def)))
    (typecheck* lhs-expr)
    (typecheck* rhs-expr)))

(defmethod typecheck* ((def forall-definition))
  (declare (ignorable def))
  ;; FORALL definitions are outside the supported fragment.
  (sal-error nil "HybridSal-Light cannot handle forall definitions"))

(defmethod valid-lhs? ((ex t))
  (declare (ignorable ex))
  ;; Fallback: any left-hand side reaching this method is rejected.
  (sal-error nil "HybridSal-Light cannot handle this specification"))
(defmethod typecheck* ((init init-decl))
  ;; Check every definition or command of the initialization part, in order.
  (dolist (item (definitions-or-commands init))
    (typecheck* item)))

(defmethod typecheck* ((trans trans-decl))
  ;; Same for the transition part.
  (loop for item in (definitions-or-commands trans)
        do (typecheck* item)))

(defmethod typecheck* ((ex next-operator))
  ;; The next-state reference takes the type of the underlying name.
  (let ((var (name ex)))
    (typecheck* var)
    (setf (declared-type ex) (declared-type var))))

(defmethod typecheck* ((cmd guarded-command))
  ;; Guard first, then each assignment in order.
  (typecheck* (guard cmd))
  (dolist (asg (assignments cmd))
    (typecheck* asg)))

(defmethod typecheck* ((g guard))
  (let ((test-expr (expression g)))
    (typecheck* test-expr)))

(defmethod typecheck* ((mod asynchronous-composition))
  ;; Asynchronous and synchronous compositions share the same checks.
  (typecheck-composition mod))

(defmethod typecheck* ((mod synchronous-composition))
  (typecheck-composition mod))

(defun typecheck-composition (mod)
  "Typecheck both components of the composition MOD.  Signal a SAL error
if the typechecked components declare an output with the same ID.
Returns MOD."
  (let* ((left (typecheck* (module1 mod)))
         (right (typecheck* (module2 mod)))
         (shared-outputs (intersection (output-decls left)
                                       (output-decls right)
                                       :key #'id)))
    (unless (null shared-outputs)
      (sal-error mod
        "Output variables must be distinct"))
    mod))

;;; RENAME a TO x, b TO y IN mod
;;; state-type is mod.STATE - {a, b}
(defmethod typecheck* ((mod renaming))
  "Typecheck the renamed module and each rename of MOD.  Then cache MOD's
input/output/global/local declaration lists from the inner module
(duplicates removed, filtered by kind).  Returns MOD.
NOTE(review): the renamed variables are not removed from these lists."
  (typecheck* (module mod))
  ;; RENAMES holds the individual renames; check each one rather than
  ;; handing the whole collection to TYPECHECK*.
  (let ((rns (renames mod)))
    (if (listp rns)
	(dolist (rn rns)
	  (typecheck* rn))
	(typecheck* rns)))
  (setf (input-decls mod)
	(remove-if (complement #'input-decl?)
	  (remove-duplicates
	      (input-decls (module mod)))))
  (setf (output-decls mod)
	(remove-if (complement #'output-decl?)
	  (remove-duplicates
	      (output-decls (module mod)))))
  (setf (global-decls mod)
	(remove-if (complement #'global-decl?)
	  (remove-duplicates
	      (global-decls (module mod)))))
  (setf (local-decls mod)
	(remove-if (complement #'local-decl?)
	  (remove-duplicates
	      (local-decls (module mod)))))
  mod)

(defmethod typecheck* ((rn rename))
  ;; The right-hand side must have the same type as the left-hand side.
  (let ((target (lhs rn)))
    (typecheck* target)
    (typecheck (rhs rn) (declared-type target))))

;;; WITH INPUT a; OUTPUT b mod
;;; state-type is mod.STATE + {a, b}
(defmethod typecheck* ((mod with-module))
  "Typecheck the variables declared by the WITH clause of MOD and its inner
module.  Then cache MOD's input/output/global/local declaration lists as
the new variables plus the inner module's lists, with duplicates removed
and filtered by kind.  Returns MOD."
  ;; NEW-VAR-DECLS is a list (it is APPENDed below), so check each
  ;; declaration rather than passing the list itself to TYPECHECK*.
  (dolist (vdecl (new-var-decls mod))
    (typecheck* vdecl))
  (let ((*with-variables* (new-var-decls mod)))
    (typecheck* (module mod))
    (setf (input-decls mod)
	  (remove-if (complement #'input-decl?)
	    (remove-duplicates
		(append *with-variables* (input-decls (module mod))))))
    (setf (output-decls mod)
	  (remove-if (complement #'output-decl?)
	    (remove-duplicates
		(append *with-variables* (output-decls (module mod))))))
    (setf (global-decls mod)
	  (remove-if (complement #'global-decl?)
	    (remove-duplicates
		(append *with-variables* (global-decls (module mod))))))
    (setf (local-decls mod)
	  (remove-if (complement #'local-decl?)
	    (remove-duplicates
		(append *with-variables* (local-decls (module mod))))))
    mod))

(defmethod typecheck* ((mod multi-synchronous))
  ;; No extra checks; defer to the next most specific method.
  (call-next-method))

(defmethod typecheck* ((mod multi-asynchronous))
  ;; No extra checks; defer to the next most specific method.
  (call-next-method))

(defmethod typecheck* ((mod multi-composition))
  ;; Check the declaration held in VAR-DECL, then the module body.
  (typecheck* (var-decl mod))
  (typecheck* (module mod)))

;;; The following methods follow the definitions given in the manual.
;;; They assume that typechecking has already been done.  For example,
;;; for compositions output-vardecls is the union of the component
;;; output-vardecls, but since typechecking guarantees that those are
;;; disjoint, we simply append them.

(defmethod output-vardecls ((mod base-module))
  (remove-if-not #'output-decl? (declarations mod)))

(defmethod output-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (output-vardecls target)))

(defmethod output-vardecls ((mod composition))
  ;; Component outputs are checked to be distinct, so no dedup is needed.
  (let ((first-outs (output-vardecls (module1 mod)))
        (second-outs (output-vardecls (module2 mod))))
    (append first-outs second-outs)))

(defmethod input-vardecls ((mod base-module))
  (remove-if-not #'input-decl? (declarations mod)))

(defmethod input-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (input-vardecls target)))

(defmethod input-vardecls ((mod composition))
  ;; Component inputs, minus any whose ID names an output or a global of
  ;; the composition.
  (let ((taken (append (output-vardecls mod) (global-vardecls mod))))
    (remove-if #'(lambda (v) (member (id v) taken :key #'id))
               (append (input-vardecls (module1 mod))
                       (input-vardecls (module2 mod))))))

(defmethod input-vardecls ((mod with-module))
  ;; Inputs cached on MOD followed by those of the inner module.
  (let ((own (input-decls mod))
        (inner (input-vardecls (module mod))))
    (append own inner)))

(defmethod input-vardecls ((mod renaming))
  ;; Not implemented for renamings.  Signal a SAL error, as the other
  ;; unsupported constructs in this file do, instead of entering the
  ;; debugger through a leftover BREAK.
  (sal-error mod "HybridSal-Light cannot compute input variables of a renaming~%"))

(defmethod global-vardecls ((mod base-module))
  (remove-if-not #'global-decl? (declarations mod)))

(defmethod global-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (global-vardecls target)))

(defmethod global-vardecls ((mod composition))
  ;; Components may share globals, so merge them by ID instead of appending.
  (let ((g1 (global-vardecls (module1 mod)))
        (g2 (global-vardecls (module2 mod))))
    (union g1 g2 :key #'id)))

(defmethod local-vardecls ((mod base-module))
  (remove-if-not #'local-decl? (declarations mod)))

(defmethod local-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (local-vardecls target)))

(defmethod local-vardecls ((mod composition))
  ;; Locals of the two components, concatenated.
  (let ((l1 (local-vardecls (module1 mod)))
        (l2 (local-vardecls (module2 mod))))
    (append l1 l2)))

(defmethod controlled-vardecls ((mod base-module))
  (remove-if-not #'controlled-var-decl? (declarations mod)))

(defmethod controlled-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (controlled-vardecls target)))

(defmethod controlled-vardecls ((mod composition))
  ;; Controlled variables of a composition: outputs, globals and locals.
  (let ((outs (output-vardecls mod))
        (globs (global-vardecls mod))
        (locs (local-vardecls mod)))
    (append outs globs locs)))

(defmethod observed-vardecls ((mod base-module))
  (remove-if-not #'observed-var-decl? (declarations mod)))

(defmethod observed-vardecls ((mod module-instance))
  ;; Delegate to the module of the declaration this instance names.
  (let ((target (module (mod-decl (mod-name mod)))))
    (observed-vardecls target)))

(defmethod observed-vardecls ((mod composition))
  ;; Observed variables of a composition: inputs and globals.
  (let ((ins (input-vardecls mod))
        (globs (global-vardecls mod)))
    (append ins globs)))

(defun mapappend (fun list)
  "Call FUN on each element of LIST and concatenate the resulting lists.
Each result is copied first, so the returned list shares no structure
with the lists FUN returns."
  (loop for item in list
        nconc (copy-list (funcall fun item))))

(defmethod definitions ((mod base-module))
  ;; Concatenated definitions of every definition declaration of MOD.
  (mapappend #'definitions
             (remove-if-not #'def-decl? (declarations mod))))

(defmethod definitions ((mod module-instance))
  (let ((target (module (mod-decl (mod-name mod)))))
    (definitions target)))

(defmethod initializations ((mod base-module))
  ;; Concatenated definitions/commands of every init declaration of MOD.
  (mapappend #'definitions-or-commands
             (remove-if-not #'init-decl? (declarations mod))))

(defmethod initializations ((mod module-instance))
  (let ((target (module (mod-decl (mod-name mod)))))
    (initializations target)))

(defmethod transitions ((mod base-module))
  ;; Concatenated definitions/commands of every trans declaration of MOD.
  (mapappend #'definitions-or-commands
             (remove-if-not #'trans-decl? (declarations mod))))

(defmethod transitions ((mod module-instance))
  (let ((target (module (mod-decl (mod-name mod)))))
    (transitions target)))

(defmethod module ((mod observe-module))
  ;; The observed module is the first component.
  (module1 mod))

(defmethod observer ((mod observe-module))
  ;; The observer is the second component.
  (module2 mod))

(defmethod typecheck* ((mod observe-module))
  ;; Not supported yet.
  (sal-error mod "Code for typechecking observe-module missing~%"))

(defmethod lhs* ((d simple-definition))
  ;; Recurse into the definition's left-hand side.
  (let ((left (lhs d)))
    (lhs* left)))

(defmethod lhs* ((ex name-expr))
  ;; A plain name is returned unchanged.
  ex)

(defmethod typecheck* ((modinst module-instance))
  "Resolve the module named by MODINST against the declarations of the
current *CONTEXT*, matching on ID.  Store the found declaration in the
instance's module name, or signal a SAL error naming the missing module.
Returns MODINST."
  (declare (special *context*))
  (let* ((mid (id (mod-name modinst)))
	 (mdecl (find mid (context-body *context*) :key #'id)))
    (if (null mdecl)
	;; Use "~a", not "~%a": the latter printed a newline and a literal
	;; "a" and silently dropped MID from the message.
	(sal-error nil "Typecheck error: No module ~a found~%" mid)
	(setf (mod-decl (mod-name modinst)) mdecl))
    modinst))

(defmethod typecheck* ((name qualified-name-expr))
  (declare (ignorable name))
  ;; Qualified names are outside the supported fragment.
  (sal-error nil "HybridSal-Light cannot handle qualified-name-expr"))

(defun sal-message (ctl &rest args)
  "Return a string consisting of \"Sal Message: \", then CTL formatted with
ARGS, then a newline.
NOTE(review): the string is only returned, never printed.  The call in the
context typechecker discards it; confirm whether output was intended."
  (with-output-to-string (out)
    (format out "Sal Message: ~?~%" ctl args)))

