
open Syntax
open Transform
open Metadata
open Syntaxutil
open Names
open Util
open Errors
open Scanf
open Parseutils
open Prettyutil
open Oldquery
open Ctojekyll

(* Alias for the hash-table module behind the lookup tables in this file
   (definers, temps, lambdas); keys are the strings returned by id_str. *)
module SH = Data.StringCols.Hash

(* Raised when a construct does not have the shape a decoder in this file
   expects.  decode_impls catches it and leaves the declaration as it was. *)
exception CantDecode


(* ------------------------------------------------------------
 * Util functions
 * ------------------------------------------------------------ *)

(* Give up on the current decoding attempt by raising CantDecode.
   Both arguments (a metadata value and a message at the call sites in this
   file) are accepted but not used. *)
let cantdecode _m _msg = raise CantDecode

(* Pass [msg] to [fatal] together with the span stored in metadata [m]. *)
let terror m msg =
	let where = m.span in
	fatal where msg

(* Split a field reference into (object expression, field id).
   - Field with flag true: the object is wrapped in a call to the "*"
     operator, built with copies (cpm) of the outer metadata.
   - Field with flag false whose object is a Parens node: the expression
     inside the parentheses is returned.
   - Anything else (including a flag-false Field whose object is not
     parenthesised): None. *)
let unpack_fieldref (meta,e) = match e with
	| Field (true,base,fld) ->
		let deref_op = (cpm meta,Var (cpm meta,"*")) in
		Some ((meta,FunCall (deref_op,[base])),fld)
	| Field (false,(_,Parens base),fld) -> Some (base,fld)
	| _ -> None

(* True when [expr] is a call whose callee is the variable called [name].
   Casts around the call are looked through; every other node gives false. *)
let rec is_call_to name = function
	| Cast (_,(_,inner)) -> is_call_to name inner
	| FunCall ((_,Var callee),_) -> id_str callee = name
	| _ -> false
	

(* ------------------------------------------------------------
 * Find info in a boilerplate expression
 * ------------------------------------------------------------ *)

(* Strip any outer layers of parentheses and casts from an expression and
   return the first node underneath that is neither. *)
let rec simplify_expr (meta,node) = match node with
	| Parens inner | Cast (_,inner) -> simplify_expr inner
	| _ -> meta,node

(* Matcher for a one-argument call: after simplify_expr, if [expr] is a call
   to the variable called [name] with exactly one argument, return [f] applied
   to that argument; otherwise None. *)
let u_1call name f expr =
	let _,node = simplify_expr expr in
	match node with
	| FunCall ((_,Var callee),[arg]) ->
		if id_str callee = name then f arg else None
	| _ -> None

(* Matcher for a one-argument call to "&": applies [f] to the operand. *)
let u_addr f expr = u_1call "&" f expr

(* Matcher for field accesses.  After simplify_expr, [f base field] is
   returned for either of two shapes: a Field with flag true, or a call to
   "*" whose single argument is a Field with flag false.  Otherwise None. *)
let u_field f expr =
	match snd (simplify_expr expr) with
	| Field (true,base,fld)
	| FunCall ((_,Var (_,"*")),[_,Field (false,base,fld)]) -> f base fld
	| _ -> None
	
(* Matcher for variables: after simplify_expr, if [expr] is a variable whose
   id satisfies [nametest], return [f id]; otherwise None. *)
let u_namekind nametest f expr =
	match simplify_expr expr with
	| _,Var id -> if nametest id then f id else None
	| _ -> None

(* Matcher that replaces a variable accepted by [nametest] with the entry
   stored under its name in [hash].  A matching variable with no entry is
   reported through pfatal; [namestr] names the kind of thing in the message. *)
let u_subst_name hash nametest namestr =
	let lookup id =
		match SH.maybe_find (id_str id) hash with
		| Some _ as found -> found
		| None ->
			pfatal (id_meta id)
				(str "No such" <+> str namestr <++> str "expression available")
	in
	u_namekind nametest lookup

(* Try each matcher in [pats] on [expr], left to right, and return the payload
   of the first one that answers Some.  If none matches (or the list is
   empty), [expr] itself is returned. *)
let rec match_first expr = function
	| [] -> expr
	| pat :: remaining ->
		(match pat expr with
		| Some hit -> hit
		| None -> match_first expr remaining)

(* Replace [expr] by its entry in [hash] when it is a variable accepted by
   [nametest]; any other expression is returned unchanged. *)
let subst_expr_for_var hash nametest namestr expr =
	let by_name = u_subst_name hash nametest namestr in
	match_first expr [by_name]
	

(* ------------------------------------------------------------
 * Closure variable <-> "*_cenv->var"
 * ------------------------------------------------------------ *)

(* Rewrite a "*" call around a flag-true field access on a variable accepted
   by is_thisenv_name into a plain reference to the field's id.  The outer
   metadata is kept; every other expression is returned as is. *)
let decode_closure_vars (meta,node) =
	let decoded = match node with
		| FunCall ((_,Var (_,"*")),[_,Field (true,(_,Var env),var)])
			when is_thisenv_name env -> Var var
		| other -> other
	in
	meta,decoded


(* ------------------------------------------------------------
 * Indirect dictionary method <-> dictname->name
 * ------------------------------------------------------------ *)

(* Rewrite a method selected out of an environment into a plain reference to
   the method id.  Two shapes are recognised: a field of a variable accepted
   by is_env_name, and a flag-true field of a flag-true field where both the
   base variable and the middle field id are accepted by is_env_name. *)
let decode_method_call (meta,node) =
	let decoded = match node with
		| Field (_,(_,Var dict),meth) when is_env_name dict -> Var meth
		| Field (true,(_,Field (true,(_,Var base),dict)),meth)
			when is_env_name dict && is_env_name base -> Var meth
		| other -> other
	in
	meta,decoded


(* ------------------------------------------------------------
 * Dictionary method <-> tyname_methodname
 * ------------------------------------------------------------ *)
	
(* Turn an id of the form "<type>_<method>" into an id carrying only the
   method name, provided the first part is a known struct or base type name
   and the second part is a registered method.  The first component of the id
   is kept.  Ids that do not scan, or fail the checks, are returned unchanged. *)
let decode_method_name id =
	let rename tyname methodname =
		let known_type = is_struct tyname || is_basetyp_name tyname in
		if known_type && is_method methodname then (fst id,methodname)
		else id
	in
	try Scanf.sscanf (id_str id) "%[^'_']_%[^'_']" rename
	with _ -> id

(* Register the name of every function signature of an interface definition
   via add_method, then return the definition unchanged. *)
let find_iface_methods ifdef =
	let _,funsigs = ifdef in
	let register sigdecl = add_method (id_str (simpledecl_name sigdecl)) in
	List.iter register funsigs;
	ifdef


(* ------------------------------------------------------------
 * Find type params for a struct
 * ------------------------------------------------------------ *)

(* For a named struct definition with an empty type-variable list, fill the
   list in from old_struct_tyvars.  Other types are returned unchanged. *)
let fill_in_struct_typarams = function
	| TyStruct (kind,(Some name as tag),Some (m,[],fields)) ->
		TyStruct (kind,tag,Some (m,old_struct_tyvars name,fields))
	| other -> other


(* ------------------------------------------------------------
 * tagged union <-> struct with a _tag field
 * ------------------------------------------------------------ *)
		
(* True when at least one name declared by [decl] satisfies is_tag_name. *)
let is_tag decl =
	let names = decl_names decl in
	List.exists is_tag_name names

(* Check that the tag field of a tagged struct is an anonymous enum with a
   body, and return the first component of each of its tag entries.
   Any other shape is reported through mfatal. *)
let check_tagged_enum (m,enum) =
	match enum with
	| (_,TyEnum (None,Some tags)),_ ->
		List.map (fun (tag_id,_) -> tag_id) tags
	| _ -> mfatal m "The tag of a tagged struct must be an enum"

(* Remove the STaggedOnly specifier from a declaration when list_extract
   finds it; otherwise return the declaration untouched. *)
let strip_taggedvoid decl =
	let m,((specs,core),init_decls) = decl in
	match list_extract STaggedOnly specs with
	| None -> decl
	| Some remaining -> m,((remaining,core),init_decls)

(* Check that the body field of a tagged struct is an anonymous SKUnion with
   no type variables, declared with the single plain name "_body", and return
   its fields.  Any other shape is reported through mfatal. *)
let check_tagged_union (m,union) =
	let reject () = mfatal m
		"The body of a tagged struct must be an anonymous struct called \"_body\"" in
	match union with
	| (_,TyStruct (SKUnion,None,Some (_,[],fields))),[([],Some id),None] ->
		if id_str id = "_body" then fields else reject ()
	| _ -> reject ()

(* True when [decl] is a simple declaration whose name equals [name]. *)
let decl_has_name name decl =
	if is_simpledecl decl then id_equal (simpledecl_name decl) name
	else false

(* Find the declaration in [fields] named [name].  When there is none,
   synthesise a declaration of [name] with specifier SCoreHere and core type
   TyVoid, carrying fresh metadata from nometa. *)
let field_for_name fields name =
	let void_field () =
		nometa (),(([SCoreHere],TyVoid),[([],Some name),None]) in
	match list_first_match (decl_has_name name) fields with
	| None -> void_field ()
	| Some decl -> decl

(* Convert the details of a struct that encodes a tagged union.  The field
   list must be exactly [tag enum; body union].  Each tag name is registered
   with add_constructor, and the result has one field per tag: the matching
   union member, or the void placeholder from field_for_name. *)
let process_tagged_details (m,tyvars,fields) =
	match fields with
	| [enum;union] ->
		let tags = check_tagged_enum enum in
		let bodies = check_tagged_union union in
		List.iter (fun tag -> add_constructor (id_str tag)) tags;
		m,tyvars,List.map (field_for_name bodies) tags
	| _ ->
		mfatal m "The fields of a tagged struct must be a tag and a union"
	
(* Turn a struct that old_tagged reports as tagged into an SKTagged struct.
   - A definition (SKStruct with a non-empty field list) has its details
     converted by process_tagged_details and its constructor names registered
     through Parseutils.register_connames.
   - A reference without a body only has its kind changed.
   Everything else is returned unchanged. *)
let decode_tagged_struct ctyp =
	match ctyp with
	| TyStruct (SKStruct,(Some name as tag),Some ((_,_,_ :: _) as details))
		when old_tagged name ->
		let tagged =
			TyStruct (SKTagged,tag,Some (process_tagged_details details)) in
		Parseutils.register_connames tagged;
		tagged
	| TyStruct (_,(Some name as tag),None) when old_tagged name ->
		TyStruct (SKTagged,tag,None)
	| _ -> ctyp


(* ------------------------------------------------------------
 * Remove hidden statements
 * ------------------------------------------------------------ *)

(* Maps the name of each simple declaration seen by add_definer to that
   declaration's metadata. *)
let definers = (SH.create () : metadata SH.t)

(* Record the metadata of a simple declaration in [definers] under the
   declared name.  Other declarations are ignored. *)
let add_definer decl =
	let m,_ = decl in
	match is_simpledecl decl with
	| true -> SH.add (id_str (simpledecl_name decl)) m definers
	| false -> ()
	
(* Tag the recorded declaration of [id] as Hidden, if [definers] has one. *)
let hide_definer id =
	let name = id_str id in
	if SH.mem name definers then
		add_tag (SH.find name definers) Hidden

(* Drop statements tagged Hidden from a statement list.
   When the statement under a Case or Label is hidden and another statement
   follows, that next statement is moved under the Case/Label in its place
   and the result is examined again.  A Case/Label whose hidden body is the
   last statement of the list is kept as it is. *)
let rec remove_hidden_slist = function
	| [] -> []
	| (m,_) :: tail when has_tag m Hidden ->
		remove_hidden_slist tail
	| (m,Case (pat,(inner,_))) :: next :: tail when has_tag inner Hidden ->
		remove_hidden_slist ((m,Case (pat,next)) :: tail)
	| (m,Label (lbl,(inner,_))) :: next :: tail when has_tag inner Hidden ->
		remove_hidden_slist ((m,Label (lbl,next)) :: tail)
	| keep :: tail -> keep :: remove_hidden_slist tail
	
(* Keep only the declarations whose metadata is not tagged Hidden. *)
let remove_hiddendecls decls =
	List.filter (fun (m,_) -> not (has_tag m Hidden)) decls

(* Remove Hidden declarations and Hidden statements from a block. *)
let remove_hiddens (decls,stmts) =
	let visible_decls = remove_hiddendecls decls in
	let visible_stmts = remove_hidden_slist stmts in
	visible_decls,visible_stmts


(* ------------------------------------------------------------
 * Constructor pattern <-> Constant pattern using tag constant
 * ------------------------------------------------------------ *)

(* Turn "case <constructor>:" into a tag pattern.
   When the case body is an assignment "v = <anything>.<field>" whose field
   id equals the constructor id, the assignment is tagged Hidden, the
   declaration of v is hidden via hide_definer, and the pattern binds v.
   Otherwise the pattern binds nothing.  Statements that are not a case on a
   registered constructor are returned unchanged. *)
let decode_conpat (m,stmt) = match stmt with
	| Case (PConst (_,Var tid),body) when is_constructor (id_str tid) ->
		begin match body with
		| (m2,SExp (_,Assign ((_,Var vid),"=",(_,Field (_,_,fid)))))
			when id_equal tid fid ->
			add_tag m2 Hidden;
			hide_definer vid;
			m,Case (PTag (tid,Some vid),body)
		| _ -> m,Case (PTag (tid,None),body)
		end
	| _ -> m,stmt
		

(* ------------------------------------------------------------
 * Tagged scrutinee <-> x._tag or x->_tag
 * ------------------------------------------------------------ *)

(* If [e] is a reference to a field called "_tag" (as recognised by
   unpack_fieldref), return the object it is selected from; otherwise [e]. *)
let decode_tagged_scrut e =
	match unpack_fieldref e with
	| Some (scrutinee,fld) -> if id_str fld = "_tag" then scrutinee else e
	| None -> e

(* Apply decode_tagged_scrut to the scrutinee of a switch statement; other
   statements are returned unchanged. *)
let decode_tagged_switch (m,stmt) =
	let decoded = match stmt with
		| Switch (scrut,blk) -> Switch (decode_tagged_scrut scrut,blk)
		| other -> other
	in
	m,decoded

	
(* ------------------------------------------------------------
 * Methods in a Dictionary should have the correct kind
 * ------------------------------------------------------------ *)

(* Rewrite the declaration of a function definition with as_dictmethod,
   leaving the remaining parts of the definition untouched. *)
let set_fundef_dict ifaceid tyname tyiface (m,(decl,krs,block)) =
	let method_decl = as_dictmethod ifaceid tyname tyiface decl in
	m,(method_decl,krs,block)

(* For a dictionary implementation, pass every method definition through
   set_fundef_dict.  Dictionary prototypes are returned unchanged. *)
let set_dict_method_kinds ddef =
	let typ,iface,body = ddef in
	match body with
	| DictProto _ -> ddef
	| DictImpl (iflist,mdefs) ->
		let retag = set_fundef_dict iface (typ_name_nocheck typ) iflist in
		typ,iface,DictImpl (iflist,List.map retag mdefs)


(* ------------------------------------------------------------
 * Temporary variable <-> variable assignments in same block
 * ------------------------------------------------------------ *)
	 
(* Flatten an access expression into a path of strings, outermost component
   first and the root variable last.  A flag-true field contributes its name
   followed by "*"; a one-argument call to "*" contributes "*".  Parentheses
   are skipped.  Any other node ends the path with []. *)
let rec expr_to_path (_,expr) = match expr with
	| Parens inner -> expr_to_path inner
	| Var id -> [id_str id]
	| Field (via_ptr,base,fld) ->
		let rest = expr_to_path base in
		id_str fld :: (if via_ptr then "*" :: rest else rest)
	| FunCall ((_,Var fn),[arg]) when id_str fn = "*" ->
		"*" :: expr_to_path arg
	| _ -> []
		 
(* Wrap [init] as Some Init expression: the Init node carries [m] and the
   enclosing expression carries cpm m. *)
let ti m init =
	let init_meta = cpm m in
	Some (init_meta,Init (m,init))

(* Describe a "lhs = rhs" expression statement as a triple of statement
   metadata, path of the left-hand side, and right-hand side.  Labels and
   cases are looked through, but the metadata returned is that of the
   innermost statement.  Other statements give None. *)
let rec tmpdef_for_stmt (m,stmt) = match stmt with
	| Label (_,inner) | Case (_,inner) -> tmpdef_for_stmt inner
	| SExp (_,Assign (target,"=",value)) -> Some (m,expr_to_path target,value)
	| _ -> None

(* Keep an assignment only when [path] is a suffix of its path (per
   list_suffix), and return it with that suffix stripped. *)
let tmpdef_for_path path (m,p,e) =
	match list_suffix path p with
	| false -> None
	| true -> Some (m,list_strip_suffix path p,e)

(* Pretty-print a path produced by expr_to_path.  A leading "*" prints as
   "*" followed by the parenthesised rest; a name followed by "*" prints as
   "rest -> name"; any other name prints as "rest . name". *)
let rec pprint_path = function
	| [] -> empty
	| "*" :: rest -> str "*" <+> parens (pprint_path rest)
	| fld :: "*" :: rest -> pprint_path rest <+> str "->" <+> str fld
	| [name] -> str name
	| fld :: rest -> pprint_path rest <+> str "." <+> str fld

(* Build the initialiser expression for one assignment, whose path [p] is
   relative to the temporary being reconstructed.
   - path ["_tag"] assigned a variable: a constructor application on that
     tag, with the optional argument taken from path [tag name; "_body"];
   - empty path assigned a call to GC_malloc: an allocation whose content
     comes from path ["*"];
   - empty path otherwise: the assigned node under the statement metadata;
   - anything else: None. *)
let rec tempexp_for_tmpdef key tmpdefs (m,p,(m2,e)) =
	match p with
	| ["_tag"] ->
		begin match e with
		| Var tag ->
			let body =
				tempexp_for_path_maybe m2 key tmpdefs [id_str tag;"_body"] in
			ti m (TConApp (tag,body))
		| _ -> None
		end
	| [] ->
		if is_call_to "GC_malloc" e then
			ti m (TAlloc (New,tempexp_for_path m key tmpdefs ["*"]))
		else
			Some (m,e)
	| _ -> None

(* An assignment whose relative path is a single name yields a struct field
   (name, initialiser for that name). *)
and structfield_for_tmpdef key tmpdefs (m,p,_) =
	match p with
	| [name] -> Some (name,tempexp_for_path m key tmpdefs [name])
	| _ -> None

(* Build a struct initialiser from the single-name assignments in [tmpdefs];
   None when there are no such fields. *)
and tempstruct_for_tmpdefs m key tmpdefs =
	let fields = option_map (structfield_for_tmpdef key tmpdefs) tmpdefs in
	match fields with
	| [] -> None
	| _ -> ti m (TStruct fields)

(* Look for the initialiser of [path].  Every assignment under [path] is
   tagged MaybeHidden key before decoding.  Exactly one direct initialiser is
   returned as is; none falls back to a struct initialiser; more than one
   aborts with cantdecode. *)
and tempexp_for_path_maybe m key tmpdefs path =
	let forpath = option_map (tmpdef_for_path path) tmpdefs in
	let mark (stmt_meta,_,_) = add_tag stmt_meta (MaybeHidden key) in
	List.iter mark forpath;
	let decoded = option_map (tempexp_for_tmpdef key forpath) forpath in
	match decoded with
	| [single] -> Some single
	| [] -> tempstruct_for_tmpdefs m key forpath
	| _ -> cantdecode m
		(str "Jekyll temporary is multiply initialised for path:" <++>
			pprint_path path)

(* Like tempexp_for_path_maybe, but a missing initialiser aborts with
   cantdecode. *)
and tempexp_for_path m key tmpdefs path =
	match tempexp_for_path_maybe m key tmpdefs path with
	| None -> cantdecode m
		(str "Jekyll temporary is not initialised for path:" <++>
			pprint_path path)
	| Some found -> found
		
(* Maps a declared name to the key and initialiser expression reconstructed
   for it by decode_tempinit_decl. *)
let temps = (SH.create () : (key * expression) SH.t)
	 
(* Try to reconstruct an initialiser for a simple declaration from the
   assignments [tmpdefs] of the same block, and record it in [temps] under
   the declared name together with a fresh key.
   Declarations that are not simple are ignored.  When the assignments do not
   form a decodable initialiser (CantDecode), nothing is recorded.  Tags
   already added while trying are left in place.
   Only CantDecode is absorbed here, as in decode_impls; any other exception
   is an unexpected failure and is allowed to propagate. *)
let decode_tempinit_decl tmpdefs (m,_ as decl) =
	if is_simpledecl decl then begin
		let name = id_str (simpledecl_name decl) in
		let key = new_key () in
		try
			let texp = tempexp_for_path m key tmpdefs [name] in
			SH.add name (key,texp) temps
		with CantDecode -> ()
	end
	 
(* Offer the reconstructed initialiser of a temporary at a use of it.
   Applies to a variable that has an entry in both [definers] and [temps] and
   is not yet tagged AlreadyExpanded.  The use is tagged AlreadyExpanded, the
   declaration is tagged MaybeHidden with the temporary's key, and the result
   is a JklNonDet choice between the initialiser and the original variable,
   with the original as the third component. *)
let insert_temp (m,expr) =
	match expr with
	| Var id ->
		let name = id_str id in
		let expandable =
			SH.mem name definers && SH.mem name temps
			&& not (has_tag m AlreadyExpanded) in
		if not expandable then (m,expr)
		else begin
			let key,texp = SH.find name temps in
			add_tag m AlreadyExpanded;
			add_tag (SH.find name definers) (MaybeHidden key);
			cpm m,JklNonDet (key,[texp;(m,expr)],(m,expr))
		end
	| _ -> m,expr
	 
(* Block pass: record every declaration of the block in [definers], collect
   the block's assignments, and try to reconstruct an initialiser for each
   declaration.  The block itself is returned unchanged. *)
let decode_tempinits (decls,stmts : block) =
	List.iter add_definer decls;
	let assignments = option_map tmpdef_for_stmt stmts in
	let try_decode = decode_tempinit_decl assignments in
	List.iter try_decode decls;
	decls,stmts
	 	 

(* ------------------------------------------------------------
 * Dict env temp/lambda env temps <-> _dt/_ft prefix
 * ------------------------------------------------------------ *)
	 
(* False only for a simple declaration whose name satisfies is_funtemp_name
   or is_dicttemp_name; true for everything else. *)
let is_normdecl decl =
	match is_simpledecl decl with
	| false -> true
	| true ->
		let name = simpledecl_name decl in
		not (is_funtemp_name name) && not (is_dicttemp_name name)
	 
(* Block pass: drop the declarations rejected by is_normdecl and keep the
   statements as they are. *)
let decode_tempenvs (decls,stmts : block) =
	let kept = List.filter is_normdecl decls in
	kept,stmts


(* ------------------------------------------------------------
 * Environment args <-> _de/_dm/_cenv/_ft prefix 
 * ------------------------------------------------------------ *)

(* False for a variable or field whose id satisfies is_env_name, also when it
   sits under one or more "&" calls; true for every other expression. *)
let rec is_real_exp (_,expr) = match expr with
	| Var id | Field (_,_,id) -> not (is_env_name id)
	| FunCall ((_,Var (_,"&")),[operand]) -> is_real_exp operand
	| _ -> true
	
(* False only for a simple declaration whose name satisfies is_env_name. *)
let is_real_decl decl =
	if is_simpledecl decl then not (is_env_name (simpledecl_name decl))
	else true
	
(* Clean a call's argument list.  A cast variable immediately followed by a
   variable satisfying is_funenv_name collapses to the bare variable (the cast
   and the following argument are dropped).  Remaining arguments are kept only
   when is_real_exp accepts them. *)
let rec filter_argexps = function
	| [] -> []
	| (m,Cast (_,(_,Var fn))) :: (_,Var env) :: rest
		when is_funenv_name env ->
		(m,Var fn) :: filter_argexps rest
	| arg :: rest when is_real_exp arg -> arg :: filter_argexps rest
	| _ :: rest -> filter_argexps rest
	
(* Apply filter_argexps to the arguments of a function call; other
   expressions are returned unchanged. *)
let decode_envargs_funcall (m,expr) =
	let decoded = match expr with
		| FunCall (callee,args) -> FunCall (callee,filter_argexps args)
		| other -> other
	in
	m,decoded
		
(* For a parameter declaring exactly one named, uninitialised declarator,
   return what split_tycontext_name makes of that name; otherwise None. *)
let decode_tycontext_arg (_,(_,declarators)) =
	match declarators with
	| [(_,Some id),None] -> split_tycontext_name id
	| _ -> None

(* For a full parameter list: append the interfaces recovered from the
   parameters by decode_tycontext_arg to the existing ones, and drop the
   parameters rejected by is_real_decl.  Other argument forms are unchanged. *)
let decode_envparams details =
	match details with
	| ArgsFull (m,ifaces,decls) ->
		let context_ifaces = option_map decode_tycontext_arg decls in
		let real_decls = List.filter is_real_decl decls in
		ArgsFull (m,ifaces @ context_ifaces,real_decls)
	| _ -> details
	

(* ------------------------------------------------------------
 * Lambda Expression <-> function with name prefix _ff
 * ------------------------------------------------------------ *)

(* Maps the name of each lambda function removed by take_lambda to the
   LocalFun expression built from it. *)
let lambdas = (SH.create () : expression SH.t)

(* Return the parameters of a lambda that is_real_decl accepts.  The
   arguments must be a full list with no interfaces; anything else is
   reported through mfatal at [m]. *)
let arg_simpldecls m = function
	| ArgsFull (_,[],decls) -> List.filter is_real_decl decls
	| _ -> mfatal m "Lambda expressions must have full arg declarations"

(* Filter predicate with a side effect: a function definition whose name
   satisfies is_lambda_name is stored in [lambdas] as a LocalFun expression
   and rejected (false).  Every other declaration is kept (true). *)
let take_lambda (m,extdecl) =
	match extdecl with
	| Func (decl,_,body) ->
		let id,_,args,_ = split_fundecl decl in
		if not (is_lambda_name id) then true
		else begin
			let closure = m,LocalFun (arg_simpldecls m args,body) in
			SH.add (id_str id) closure lambdas;
			false
		end
	| _ -> true
	
(* Remove the lambda functions from a program, recording each in [lambdas]. *)
let take_lambdas program =
	List.filter (fun extd -> take_lambda extd) program

(* Replace a variable naming a lambda by the expression stored in [lambdas]. *)
let lamsubst expr =
	subst_expr_for_var lambdas is_lambda_name "lambda" expr

(* Replace "&name", where name is a lambda name, by the expression stored for
   it in [lambdas].  Other expressions are returned unchanged. *)
let insert_lambda expr =
	let addr_of_lambda =
		u_addr (u_subst_name lambdas is_lambda_name "lambda") in
	match_first expr [addr_of_lambda]


(* ------------------------------------------------------------
 * Hidden struct <-> struct with an env prefix
 * ------------------------------------------------------------ *)

(* False only for a declaration of a named struct whose name satisfies
   is_env_name. *)
let is_real_struct (_,extdecl) = match extdecl with
	| Decl (_,((_,TyStruct (_,Some id,_)),_)) -> not (is_env_name id)
	| _ -> true

(* Drop the struct declarations rejected by is_real_struct from a program. *)
let remove_hidden_structs program =
	List.filter (fun extd -> is_real_struct extd) program


(* ------------------------------------------------------------
 * Function type with _cenv argument <-> closure
 * ------------------------------------------------------------ *)

(* True for a simple declaration whose name satisfies is_thisenv_name. *)
let is_lamenv decl =
	if is_simpledecl decl then is_thisenv_name (simpledecl_name decl)
	else false

(* A SimpleFun function type whose first parameter satisfies is_lamenv
   becomes a Closure type without that parameter.  Other declarator modifiers
   are returned unchanged. *)
let decode_closure_type = function
	| DFun (ArgsFull (m,ifaces,env :: params),SimpleFun) when is_lamenv env ->
		DFun (ArgsFull (m,ifaces,params),Closure)
	| other -> other


(* ------------------------------------------------------------
 * TypeClass Implementation <-> struct followed by functions
 * ------------------------------------------------------------ *)

	
(* Return (type, name, initialiser) of a simple declaration; raise CantDecode
   for any other declaration. *)
let maybe_simpledecl decl =
	match is_simpledecl decl with
	| false -> raise CantDecode
	| true -> simpledecl_typ decl, simpledecl_name decl, simpledecl_init decl
	
(* Return (name, details) of a named struct type; raise CantDecode for any
   other type. *)
let maybe_struct = function
	| TyStruct (_,Some id,details) -> id,details
	| _ -> raise CantDecode
	
(* Return the declaration inside a Decl external declaration; raise
   CantDecode for any other external declaration. *)
let maybe_decl (_,extdecl) = match extdecl with
	| Decl inner -> inner
	| _ -> raise CantDecode
	
(* Read an external declaration as a dictionary declaration: a simple
   declaration of a named struct type.  Returns (interface id = struct name,
   dictionary id = declared name, isproto); isproto is true when the
   declaration has no initialiser.  Raises CantDecode on any other shape. *)
let maybe_impldecl extdecl =
	let (_,ctyp,_),dictid,init = maybe_simpledecl (maybe_decl extdecl) in
	let ifaceid,_ = maybe_struct ctyp in
	let isproto = (match init with None -> true | Some _ -> false) in
	ifaceid,dictid,isproto
	
(* Extract the type name from a dictionary id of the form
   "_dm<interface>_<type>".  Raises CantDecode when the id does not scan or
   the interface part differs from the name of [ifaceid]. *)
let dictid_tyname ifaceid dictid =
	let pick_type ifacestr tystr =
		if ifacestr <> id_str ifaceid then raise CantDecode;
		tystr
	in
	try Scanf.sscanf (id_str dictid) "_dm%[^'_']_%[^'_']" pick_type
	with _ -> raise CantDecode
	
(* True when [id] has the form "<type>_<method>" and its type part equals
   [tyname].  The method part is scanned but not checked.  Ids that do not
   scan give false.
   The range uses the same "%[^'_']" spelling as decode_method_name and
   dictid_tyname; the earlier version had a stray '[' inside the range. *)
let is_typemethod tyname id =
	try
		Scanf.sscanf (id_str id) "%[^'_']_%[^'_']"
			(fun tystr _methstr -> tyname = tystr)
	with _ -> false
	
(* Split [xs] into the leading run of declarations that belong to the
   dictionary for [tyname] and the rest.  Returns (consumed declarations,
   function definitions among them, remaining declarations).
   Consumed are: simple declarations named "<tyname>_..." when [isproto];
   function definitions named "<tyname>_..." when not [isproto] (also
   collected, under copied metadata, as definitions); and declarations of a
   named struct whose name satisfies is_dictenv_name.  The run stops at the
   first declaration that is none of these. *)
let rec consume_methods tyname isproto xs =
	let is_method_of decl =
		is_simpledecl decl && is_typemethod tyname (simpledecl_name decl) in
	match xs with
	| [] -> [],[],[]
	| (_,Decl decl as proto) :: rest when is_method_of decl && isproto ->
		let methods,fdefs,others = consume_methods tyname isproto rest in
		proto :: methods,fdefs,others
	| (m,Func (decl,_,_ as fundef) as def) :: rest
		when is_method_of decl && not isproto ->
		let methods,fdefs,others = consume_methods tyname isproto rest in
		def :: methods,(cpm m,fundef) :: fdefs,others
	| (_,Decl (_,((_,TyStruct (_,Some name,_)),_)) as envdecl) :: rest
		when is_dictenv_name name ->
		let methods,fdefs,others = consume_methods tyname isproto rest in
		envdecl :: methods,fdefs,others
	| remaining -> [],[],remaining

(* True when a candidate implementation is for the type called [tyname] and
   for the interface [ifaceid]. *)
let impl_match tyname ifaceid (typ,iface,_) =
	let same_type = id_str (typ_name_nocheck typ) = tyname in
	same_type && id_equal ifaceid iface
				
(* Build a Dict external declaration under copied metadata: a DictProto when
   [isproto], otherwise a DictImpl carrying the method definitions [fdefs]. *)
let make_impl m isproto fdefs (typ,ifacespec,tyiface) =
	let body =
		if isproto then DictProto tyiface else DictImpl (tyiface,fdefs) in
	cpm m,Dict (typ,ifacespec,body)
	
(* Try to read [extdecl], followed by the declarations [xs], as a dictionary
   (type-class implementation).  Raises CantDecode when [extdecl] is not a
   dictionary declaration.
   On success returns (choice node, consumed method declarations, remaining
   declarations).  The consumed declarations are tagged MaybeHidden under the
   key of the outer choice.  The outer choice is between the decoded form and
   the original declaration; the decoded form is itself a choice among the
   candidate implementations from get_possible_impls that match the type and
   interface. *)
let decode_impl (m,_ as extdecl) xs =
	let isimpl = new_key () in
	let ifaceid,dictid,isproto = maybe_impldecl extdecl in
	let tyname = dictid_tyname ifaceid dictid in
	let methods,fdefs,others = consume_methods tyname isproto xs in
	let mark_hidden (meta,_) = add_tag meta (MaybeHidden isimpl) in
	List.iter mark_hidden methods;
	let candidates =
		List.filter (impl_match tyname ifaceid) (get_possible_impls ()) in
	let possimpls = List.map (make_impl m isproto fdefs) candidates in
	let which = new_key () in
	let implopts = cpm m,NonDet (which,possimpls,extdecl) in
	let wrapped = cpm m,NonDet (isimpl,[implopts;extdecl],extdecl) in
	wrapped,methods,others

(* Walk a program and replace every recognisable dictionary declaration by
   the choice node from decode_impl, followed by the method declarations it
   consumed.  Declarations that decode_impl rejects with CantDecode are kept
   as they are. *)
let rec decode_impls = function
	| [] -> []
	| extdecl :: xs ->
		let attempt =
			try Some (decode_impl extdecl xs) with CantDecode -> None in
		(match attempt with
		| Some (impl,methods,others) -> impl :: (methods @ decode_impls others)
		| None -> extdecl :: decode_impls xs)


	
(* ------------------------------------------------------------
 * All operations combined into one pass
 * ------------------------------------------------------------ *)
					
(* Short name for the default mapper that the mappers below extend. *)
let d = mp_default

(* Map an external declaration with the default mapper, unless its metadata
   says its language is PureC, in which case it is left untouched. *)
let decode_extd mp (m,_ as extd) =
	match get_lang m with
	| PureC -> extd
	| _ -> d.m_extd mp extd
			
(* Apply the functions in [funcs] to [x] one after another, first element
   first, feeding each result to the next.  An empty list returns [x]. *)
let fseq funcs x =
	List.fold_left (fun acc f -> f acc) x funcs
	
(* Mapper combinator: run [funcs] over [x] in order, then hand the result to
   the default mapper function [dflt]. *)
let seq funcs dflt mp x =
	let prepared = fseq funcs x in
	dflt mp prepared
(* Mapper combinator: run [before] over [x], then the default mapper function
   [dflt], then [after] over its result. *)
let wseq before dflt after mp x =
	let mapped = seq before dflt mp x in
	fseq after mapped
			
(* The decoding mapper: the default mapper with every decoder of this file
   installed.  Passes listed before a default function run on a node before
   its children are mapped by that default; passes listed after it run on the
   mapped result. *)
let mp_decode =
	let expr_passes = [
		insert_temp;
		decode_closure_vars;
		insert_lambda;
		decode_envargs_funcall;
		decode_method_call;
	] in
	let stmt_passes = [decode_tagged_switch;decode_conpat] in
	let block_passes = [decode_tempinits;decode_tempenvs] in
	{d with
		m_id = decode_method_name;
		m_expr = seq expr_passes d.m_expr;
		m_stmt = seq stmt_passes d.m_stmt;
		m_blck = wseq block_passes d.m_blck [remove_hiddens];
		m_ctyp = seq [fill_in_struct_typarams] d.m_ctyp;
		m_dmod = seq [decode_closure_type] d.m_dmod;
		m_ardt = seq [decode_envparams] d.m_ardt;
		m_idef = seq [find_iface_methods] d.m_idef;
		m_ddef = wseq [] d.m_ddef [set_dict_method_kinds];
		m_extd = decode_extd;
		m_pgrm = wseq [take_lambdas] d.m_pgrm [remove_hidden_structs];
	}

(* Entry point: register [oldversion] via set_old_program, group dictionary
   implementations with decode_impls, then run the decoding mapper over the
   whole program. *)
let decode_features program oldversion =
	set_old_program oldversion;
	let grouped = decode_impls program in
	map_program mp_decode grouped

(* Mapper that only rewrites dictionary definitions, through
   set_dict_method_kinds; the mapper argument is ignored there and the default
   m_ddef is not run. *)
let mp_jkl_fixup = {d with
	m_ddef = (fun _ ddef -> set_dict_method_kinds ddef)
}
	
(* Run the dictionary-method fix-up mapper over a program. *)
let fixup_jekyll program =
	let fixed = map_program mp_jkl_fixup program in
	fixed
