import io, os, sys, re
from glob import glob
from collections import defaultdict
from argparse import ArgumentParser

def make_plain(conllu):
    """
    Reduce CoNLL-U text to a plain tokenized format.

    Token rows keep only a running token number, the word form and a
    ``BeginSeg=Yes`` flag (taken from the last column); all other columns
    become ``_``. Rows whose ID contains ``-`` or ``.`` are dropped, as are
    blank lines and every comment line except ``# newdoc`` lines. Token
    numbering runs across sentences and restarts at each ``# newdoc``.

    :param conllu: CoNLL-U formatted string, possibly holding several documents
    :return: plain format string, terminated by a blank line
    """
    placeholders = ["_"] * 7
    next_id = 1
    kept = []
    for row in conllu.split("\n"):
        if not row.strip():
            continue
        if "\t" in row:
            cols = row.split("\t")
            tok_id = cols[0]
            if "-" in tok_id or "." in tok_id:
                continue
            seg_flag = "BeginSeg=Yes" if "BeginSeg=Yes" in cols[-1] else "_"
            row = "\t".join([str(next_id), cols[1]] + placeholders + [seg_flag])
            next_id += 1
        if "# newdoc" in row:
            # A new document restarts the token counter
            next_id = 1
        elif row.startswith("#"):
            continue
        kept.append(row)
    joined = "\n".join(kept)
    # Put a blank line in front of every document header after the first
    return joined.replace("# newdoc", "\n# newdoc").strip() + "\n\n"


# Output flavor for the relation rows; "standoff" selects the token-range based
# column set, any other value selects the pre/arg1/mid/arg2/post context columns
outmode = "standoff"
# Placeholder emitted between non-adjacent token runs of a discontinuous unit
ellipsis_marker = "<*>"
# Column names of the first row written to each output table
header = (
    ["doc", "unit1_toks", "unit2_toks", "unit1_txt", "unit2_txt", "s1_toks","s2_toks","unit1_sent","unit2_sent","dir", "orig_label","label"]
    if outmode == "standoff"
    else ["doc","start_toks","pre","arg1","mid","arg2","post","dir","label"]
)

#rel_mapping = ['label', 'background', 'e-elaboration', 'interpretation', 'joint', 'circumstance', 'evaluation-n', 'elaboration', 'reason', 'evaluation-s', 'evidence', 'solutionhood', 'concession', 'purpose', 'cause', 'means', 'list', 'preparation', 'condition', 'conjunction', 'antithesis', 'contrast', 'summary', 'restatement']
# Per-corpus relation label mappings: corpus name -> {original label: output label}.
# Corpora without an entry keep their original labels; for corpora with an
# entry, make_rels raises if a label (or its lowercased form) is missing here.
rel_mapping = defaultdict(dict)
rel_mapping["deu.rst.pcc"] = {"motivation":"reason", "unless":"antithesis", "justify":"reason", "enablement":"background", "otherwise":"antithesis"}
# Fine-grained RST-DT labels collapsed into coarse classes, grouped by target class
rel_mapping["eng.rst.rstdt"] = {
    # attribution
    "attribution":"attribution","attribution-e":"attribution","attribution-n":"attribution","attribution-negative":"attribution",
    # background
    "background":"background","background-e":"background","circumstance":"background","circumstance-e":"background",
    # cause
    "cause":"cause","cause-result":"cause","result":"cause","result-e":"cause","consequence":"cause","consequence-n-e":"cause","consequence-n":"cause","consequence-s-e":"cause","consequence-s":"cause",
    # comparison
    "comparison":"comparison","comparison-e":"comparison","preference":"comparison","preference-e":"comparison","analogy":"comparison","analogy-e":"comparison","proportion":"comparison",
    # condition
    "condition":"condition","condition-e":"condition","hypothetical":"condition","contingency":"condition","otherwise":"condition",
    # contrast
    "contrast":"contrast","concession":"contrast","concession-e":"contrast","antithesis":"contrast","antithesis-e":"contrast",
    # elaboration
    "elaboration-additional":"elaboration","elaboration-additional-e":"elaboration","elaboration-general-specific-e":"elaboration","elaboration-general-specific":"elaboration",
    "elaboration-part-whole":"elaboration","elaboration-part-whole-e":"elaboration","elaboration-process-step":"elaboration","elaboration-process-step-e":"elaboration",
    "elaboration-object-attribute-e":"elaboration","elaboration-object-attribute":"elaboration","elaboration-set-member":"elaboration","elaboration-set-member-e":"elaboration",
    "example":"elaboration","example-e":"elaboration","definition":"elaboration","definition-e":"elaboration",
    # enablement
    "purpose":"enablement","purpose-e":"enablement","enablement":"enablement","enablement-e":"enablement",
    # evaluation
    "evaluation":"evaluation","evaluation-n":"evaluation","evaluation-s-e":"evaluation","evaluation-s":"evaluation",
    "interpretation-n":"evaluation","interpretation-s-e":"evaluation","interpretation-s":"evaluation","interpretation":"evaluation",
    "conclusion":"evaluation","comment":"evaluation","comment-e":"evaluation",
    # explanation
    "evidence":"explanation","evidence-e":"explanation","explanation-argumentative":"explanation","explanation-argumentative-e":"explanation","reason":"explanation","reason-e":"explanation",
    # joint
    "list":"joint","disjunction":"joint",
    # manner-means
    "manner":"manner-means","manner-e":"manner-means","means":"manner-means","means-e":"manner-means",
    # topic-comment
    "problem-solution":"topic-comment","problem-solution-n":"topic-comment","problem-solution-s":"topic-comment",
    "question-answer":"topic-comment","question-answer-n":"topic-comment","question-answer-s":"topic-comment",
    "statement-response":"topic-comment","statement-response-n":"topic-comment","statement-response-s":"topic-comment",
    "topic-comment":"topic-comment","comment-topic":"topic-comment","rhetorical-question":"topic-comment",
    # summary
    "summary":"summary","summary-n":"summary","summary-s":"summary","restatement":"summary","restatement-e":"summary",
    # temporal
    "temporal-before":"temporal","temporal-before-e":"temporal","temporal-after":"temporal","temporal-after-e":"temporal",
    "temporal-same-time":"temporal","temporal-same-time-e":"temporal","sequence":"temporal","inverted-sequence":"temporal",
    # topic-change
    "topic-shift":"topic-change","topic-drift":"topic-change",
    # textual-organization
    "textualorganization":"textual-organization",
}
# Basque labels mapped to English names. The targets "antithesis" and
# "motivation" were previously misspelled ("anthitesis", "motibation"); they now
# match the spelling used by the other mappings in this table.
# NOTE(review): this changes the labels emitted for eus.rst.ert - confirm that
# downstream consumers do not rely on the old misspellings.
rel_mapping["eus.rst.ert"] = {
    "aukera":"otherwise","antitesia":"antithesis","ahalbideratzea":"enablement","kausa":"cause","zirkunstantzia":"circumstance",
    "kontzesioa":"concession","baldintza":"condition","konjuntzioa":"conjunction","kontrastea":"contrast","disjuntzioa":"disjunction",
    "elaborazioa":"elaboration","ebaluazioa":"evaluation","ebidentzia":"evidence","testuingurua":"background","interpretazioa":"interpretation",
    "justifikazioa":"justify","lista":"list","metodoa":"means","motibazioa":"motivation","prestatzea":"preparation",
    "helburua":"purpose","birformulazioa":"restatement","ondorioa":"result","laburpena":"summary","sekuentzia":"sequence",
    "arazo-soluzioa":"solution-hood","bateratzea":"joint","alderantzizko-baldintza":"unless","ez-baldintzatzailea":"unconditional",
    "birformulazioa-nn":"restatement","definitu-gabeko erlazioa":"undefined",
}
rel_mapping["fas.rst.prstc"] = {"comparisonmult":"comparison","temporalmult":"temporal"}

# Fixed dev/test partitions, one pair of document-name lists per corpus.
# The __main__ block picks one pair based on the corpus name; any document
# that appears in neither list of the selected pair is treated as training data.

# GUM (also the fallback pair when no other corpus name matches)
gum_dev = ["GUM_interview_cyclone", "GUM_interview_gaming", "GUM_news_iodine", "GUM_news_homeopathic", "GUM_voyage_athens", "GUM_voyage_coron", "GUM_whow_joke", "GUM_whow_overalls", "GUM_bio_byron", "GUM_bio_emperor", "GUM_fiction_lunre", "GUM_fiction_beast", "GUM_academic_exposure", "GUM_academic_librarians", "GUM_speech_impeachment", "GUM_textbook_cognition", "GUM_vlog_radiology", "GUM_conversation_grounded","GUM_reddit_macroeconomics","GUM_reddit_pandas"]
gum_test = ["GUM_interview_libertarian", "GUM_interview_hill", "GUM_news_nasa", "GUM_news_sensitive", "GUM_voyage_oakland", "GUM_voyage_vavau", "GUM_whow_mice", "GUM_whow_cactus", "GUM_fiction_falling", "GUM_fiction_teeth", "GUM_bio_jespersen", "GUM_bio_dvorak", "GUM_academic_eegimaa", "GUM_academic_discrimination", "GUM_speech_austria", "GUM_textbook_chemistry", "GUM_vlog_studying", "GUM_conversation_retirement","GUM_reddit_escape","GUM_reddit_monsters"]
# PCC (deu.rst.pcc)
pcc_dev_set = ["maz-5144","maz-5297","maz-5701","maz-5709","maz-5715","maz-5873","maz-5876","maz-5932","maz-6046","maz-6159","maz-6165","maz-6193","maz-6488","maz-6539","maz-6728","maz-6918","maz-6993"]
pcc_test_set = ["maz-3367","maz-3377","maz-3415","maz-3547","maz-4031","maz-4181","maz-4282","maz-4403","maz-4428","maz-4472","maz-4636","maz-4794","maz-4959","maz-5007","maz-5010","maz-5012","maz-5039"]
# Spanish (spa.rst.rststb)
spanish_dev = ["me00025","me00026","me00027","me00028","ps00024b","ps00025b","ps00026b","ps00027b","ma00063b","ma00064b","ma00065b","ma00066b","ma00067b","ma00068b","ma00069b","ma00070b","de00002","ec00013","ec00014","li00037","li00038","li00039","li00040","li00041","as00014","as00015","in00002","se00032b","se00033b","se00034","se00035","se00036"]
spanish_test = ["me00029","me00030","me00031","me00032","ps00028b","ps00029b","ps00030b","ps00031b","ma00071b","ma00072b","ma00073b","ma00074b","ma00075b","ma00076b","ma00077a","ma00078b","de00003","ec00015","ec00016","li00042","li00043","li00044","li00045","li00046","as00016","as00017","in00003","se00037","se00038","se00039","se00040","se00041"]
# ANNODIS
annodis_dev = ["wik2_selectionNaturelle_selection","ling_fuchs_section2","geop_3_space","wik1_26_02-04-2006","wik1_27_02-04-2006","wik1_28_02-04-2006","wik1_29_02-04-2006","news_32_28-01-2003","news_33_21-01-2003","news_34_10-06-2002","news_35_03-11-2002"]
annodis_test = ["wik2_vinDeChampagne_vinification","ling_leon_contenuDinformation","geop_3_spatiaux","wik1_30_02-04-2006","wik1_31_02-04-2006","wik1_32_02-04-2006","wik1_33_02-04-2006","news_36_24-08-1999","news_37_12-06-1999","news_38_06-07-1999","news_39_04-05-2002"]
# STAC
stac_dev = ["s2-practice3","pilot20","s2-league8-game2","s2-leagueM-game2","s1-league3-game4","s1-league3-game5"]
stac_test = ["s2-practice4","pilot21","s2-leagueM-game3","s2-leagueM-game5","s1-league3-game6","s1-league3-game7"]
# RST-DT (note: test list is declared before dev list here)
rstdt_test = ["wsj_0602","wsj_0607","wsj_0616","wsj_0623","wsj_0627","wsj_0632","wsj_0644","wsj_0654","wsj_0655","wsj_0667","wsj_0684","wsj_0689","wsj_1113","wsj_1126","wsj_1129","wsj_1142","wsj_1146","wsj_1148","wsj_1169","wsj_1183","wsj_1189","wsj_1197","wsj_1306","wsj_1307","wsj_1325","wsj_1331","wsj_1346","wsj_1354","wsj_1365","wsj_1376","wsj_1380","wsj_1387","wsj_2336","wsj_2354","wsj_2373","wsj_2375","wsj_2385","wsj_2386"]
rstdt_dev = ["wsj_0605","wsj_0609","wsj_0618","wsj_0624","wsj_0629","wsj_0636","wsj_0641","wsj_0657","wsj_0658","wsj_0663","wsj_0675","wsj_0677","wsj_1109","wsj_1127","wsj_1128","wsj_1133","wsj_1135","wsj_1136","wsj_1164","wsj_1175","wsj_1181","wsj_1193","wsj_1308","wsj_1311","wsj_1327","wsj_1333","wsj_1348","wsj_1358","wsj_1362","wsj_1370","wsj_1384","wsj_1386","wsj_2391","wsj_2393","wsj_2394","wsj_2395","wsj_2396","wsj_2398"]
# SCTB, Spanish part (spa.rst.sctb)
sctb_spa_dev = ["EEP_ESP7-GS","FCEC_ESP1-GS","BMCS_ESP4-GS","CCICE_ESP4-GS","FICB_ESP4-GS","ICP_ESP5-GS","ICP_ESP6-GS","TERM39_ESP-GS","TERM40_ESP-GS"]
sctb_spa_test = ["EEP_ESP8-GS","FCEC_ESP2-GS","BMCS_ESP5-GS","CCICE_ESP5-GS","FICB_ESP5-GS","ICP_ESP7-GS","ICP_ESP8-GS","TERM50_ESP-GS","TERM51_ESP-GS"]
# Russian partitions with 2019-style document names; not referenced by the
# corpus selection in this file (the lists below without the suffix are used)
rusrt_dev2019 = ["comp_49","comp_5","comp_50","comp_51","comp_52","news_67","news_68","news_69","news_7","news_70","news_71","news_72","news_73","ling_50","ling_51","ling_52","ling_53","ling_54","ling_55"]
rusrt_test2019 = ["comp_53","comp_54","comp_6","comp_7","comp_8","news_74","news_75","news_76","news_77","news_78","news_79","news_8","news_9","ling_57","ling_58","ling_6","ling_7","ling_8","ling_9"]
# Russian (corpus names containing "rus.")
rusrt_dev = ["sci.ling_52","sci.ling_55","sci.ling_51","news1_69","sci.comp_51","news1_72","news1_70","news1_68","sci.comp_49","news1_7","sci.comp_52","sci.comp_5","news1_67","sci.comp_50","sci.ling_54","news1_71","news1_73","sci.ling_50","sci.ling_53","blogs_80","blogs_51","blogs_26","blogs_37","blogs_94","blogs_41","blogs_62","blogs_25","blogs_5","blogs_77","blogs_79"]
rusrt_test = ["sci.comp_8","news1_78","sci.ling_57","news1_77","news1_9","sci.ling_6","sci.ling_8","sci.comp_7","news1_76","news1_79","sci.ling_7","sci.comp_53","sci.comp_54","news1_74","sci.ling_9","news1_8","news1_75","sci.comp_6","sci.ling_58","blogs_100","blogs_97","blogs_17","blogs_48","blogs_19","blogs_2","blogs_45","blogs_52","blogs_7","blogs_16","blogs_57"]
# SCTB, Chinese part (zho.rst.sctb)
sctb_zho_dev = ["BMCS_CHN4-GS","FICB_CHN4-GS","EEP_CHN7-GS","TERM39_CHN-GS","TERM40_CHN-GS","CCICE_CHN4-GS","ICP_CHN5-GS","ICP_CHN6-GS","FCEC_CHN1-GS"]
sctb_zho_test = ["BMCS_CHN5-GS","FICB_CHN5-GS","EEP_CHN8-GS","TERM50_CHN-GS","TERM51_CHN-GS","CCICE_CHN5-GS","ICP_CHN7-GS","ICP_CHN8-GS","FCEC_CHN2-GS"]
# NLDT
nldt_dev = ["FL15_MS_research","FL16_Slachtofferhulp","FL17_Habitat","EE15_Amalthea","EE16_Epimetheus","EE17_Hyperion","AD15_AA_properties","AD16_Bifiene","AD17_BlueBand","PSN15_Aarde_maan_jonger","PSN16_Dwergplaneten","PSN17_Neptunus_at_een_planeet"]
nldt_test = ["FL18_Fonds_voor_het_hart","FL19_Terre_des_hommes","FL20_Artsen_zonder_grenzen","EE18_Orionnevel","EE19_Quasar","EE20_Meteoren","AD18_Nestle","AD19_Menzis","AD20_NRC_Focus","PSN18_Titan","PSN19_Aarde","PSN20_Kepler_exoplaneten"]
# Portuguese (corpus names starting with "por")
por_cstn_dev = ["D4_C36_JB","D1_C36_Folha","D3_C36_OGlobo","D2_C14_Estadao_21-08-2006_10h15","D4_C14_JB","D3_C14_OGlobo_21-08-2006_07h40","D2_C13_Estadao_07-08-2006_02h52","D3_C13_GPovo_07-08-2006_09h28","D1_C13_Folha_07-08-2010_07h21","D4_C25_JB","D1_C25_Folha_16-07-2007_07h50","D5_C25_GPovo_17-07-2007","D5_C22_GPovo_24-07-2007_08h51","D2_C22_Estadao_24-07-2007_08h58"]
por_cstn_test = ["D2_C37_GPovo_17-10-2007_13h24","D1_C37_OGlobo_17-10-2007_12h21","D3_C30_OGlobo_07-08-2007_09h10","D1_C30_Folha_07-08-2007_09h19","D2_C30_Estadao_07-08-2007_07h59","D4_C39_JB","D3_C39_OGlobo_07-08-2007_13h31","D2_C39_Estadao_07-08-2007_14h18","D1_C38_Folha_17-07-2007_12h15","D4_C38_JB","D2_C38_Estadao_17-07-2007_12h03","D1_C31_Folha_17-10-2007_09h39"]
# Basque (corpus names starting with "eus")
basque_dev = ['GMB0001', 'GMB0201', 'GMB0801', 'INF13', 'INF18', 'INF20', 'LAB12', 'LAB23', 'MIS6', 'MIS9', 'OSA08', 'OSA16', 'OSA20', 'SENTARG04', 'SENTARG07', 'SENTBER03', 'SENTQUE01', 'TERM21', 'TERM22',  'TERM29','TERM38', 'ZTF15', 'ZTF19', 'ZTF4']
basque_test = ['GMB0502', 'GMB0703', 'INF06', 'INF11', 'LAB21', 'LAB22', 'LAB81', 'MIS1', 'MIS12', 'MIS13', 'OSA01', 'OSA03', 'OSA13', 'OSA17', 'SENTARG06', 'SENTBER06', 'SENTPUT01', 'TERM17', 'TERM25', 'TERM28', 'TERM31', 'ZTF1', 'ZTF10', 'ZTF5']
# Persian (corpus names starting with "fas")
persian_dev = ["etemad001","etemad039","etemad040","etemad041","meidan001","meidan003","meidan005","meidan015","meidan025","meidan035","shargh001","shargh022","shargh025","shargh026","shargh033"]
persian_test = ["etemad007","etemad008","etemad009","etemad023","etemad024","etemad027","etemad029","etemad036","etemad056","etemad058","etemad060","meidan002","meidan010","meidan048","shargh031"]

def get_rsd(dir_path,chars2toks, toks_by_doc, conll_data, add_missing_tokens=False):
    """
    Read all ``*.rsd`` files in a directory and re-tokenize the EDU text column.

    Each tab-separated row has the EDU number in column 1 and the EDU text in
    column 2; only column 2 is rewritten, all other columns are passed through.
    Non-empty rows without a tab are kept verbatim and empty rows are dropped.

    :param dir_path: directory prefix, concatenated directly with ``*.rsd``
        (so it should end in a path separator)
    :param chars2toks: docname -> {offset of a non-whitespace character in the
        document: index of the token containing it}; used when
        ``add_missing_tokens`` is False to decide where token boundaries fall
    :param toks_by_doc: unused; kept so existing callers keep working
    :param conll_data: docname -> CoNLL-U string; used when ``add_missing_tokens``
        is True
    :param add_missing_tokens: if True, discard the EDU text found in the file
        and replace it with the space-joined token forms of the corresponding
        EDU in ``conll_data`` (EDUs are counted via rows containing ``BeginSeg``)
    :return: dict of docname -> re-tokenized rsd text, newline terminated
    :raises KeyError: if a document or character offset is missing from the
        lookup structures
    """
    output = {}
    # glob order depends on the file system; sort for a reproducible document order
    for file_ in sorted(glob(dir_path + "*.rsd")):
        # Context manager closes the handle even if reading fails
        with io.open(file_,encoding="utf8") as f:
            text = f.read()
        text = text.replace("\u200b","")  # Remove invisible (zero-width) space
        docname = os.path.basename(file_).replace(".rsd","").replace(".rst","")

        if add_missing_tokens:
            # Collect the token forms of each EDU, numbered from 1 in document order
            edu_map = defaultdict(list)
            edu_num = 0
            for line in conll_data[docname].strip().split("\n"):
                if "\t" not in line:
                    continue
                fields = line.split("\t")
                if "-" in fields[0] or "." in fields[0]:
                    continue
                if "BeginSeg" in line:
                    edu_num += 1
                edu_map[edu_num].append(fields[1])

            fixed = []
            for rsd_row in text.strip().split("\n"):
                if "\t" in rsd_row:
                    fields = rsd_row.split("\t")
                    fields[1] = " ".join(edu_map[int(fields[0])]).strip()
                    rsd_row = "\t".join(fields)
                if len(rsd_row) > 0:
                    fixed.append(rsd_row)
            output[docname] = "\n".join(fixed) + "\n"
            continue

        tokenized = []
        prev_tok = 0  # index of the token the previous character belonged to
        char_num = 0  # running offset over non-whitespace characters of the document
        for line in text.split("\n"):
            if "\t" in line:
                fields = line.split("\t")
                raw_content_chars = re.sub(r'\s','',fields[1])
                this_edu = []
                for c in raw_content_chars:
                    # A change of token index means a token boundary: insert a space
                    if chars2toks[docname][char_num] != prev_tok:
                        this_edu.append(" ")
                        prev_tok += 1
                    this_edu.append(c)
                    char_num += 1
                fields[1] = "".join(this_edu).strip()
                line = "\t".join(fields)
            if len(line) > 0:
                tokenized.append(line)

        output[docname] = "\n".join(tokenized) + "\n"
    return output

def get_conll(dir_path):
    """
    Read all ``*.conllu`` files in a directory and index their tokens per document.

    Files are split into documents on ``# newdoc``; the document name is read
    from the ``id = NAME`` text directly following that marker. Rows whose ID
    contains ``.`` or ``-`` are ignored when indexing tokens.

    :param dir_path: directory prefix, concatenated directly with ``*.conllu``
        (so it should end in a path separator)
    :return: tuple of
        - dict docname -> CoNLL-U text of the document, starting with ``# newdoc``
        - dict docname -> {offset of a non-whitespace character: token index},
          with offsets counted over the concatenated token forms
        - dict docname -> list of token forms
    :raises ValueError: if a ``# newdoc`` block has no parsable ``id``
    """
    output = {}
    chars2toks = defaultdict(dict)
    toks = defaultdict(list)
    # glob order depends on the file system; sort for a reproducible document order
    for file_ in sorted(glob(dir_path + "*.conllu")):
        # Context manager closes the handle even if reading fails
        with io.open(file_,encoding="utf8") as f:
            text = f.read().replace("\u200b","")  # Remove invisible (zero-width) space
        # Anything before the first "# newdoc" marker belongs to no document
        for part in text.split("# newdoc")[1:]:
            part = part.strip()
            match = re.search(r'^id ?= ?([^\s]+)',part)
            if match is None:
                raise ValueError("no document id after '# newdoc' in " + file_)
            docname = match.group(1)
            output[docname] = "# newdoc " + part
            tok_num = 0
            offset = 0
            for line in part.split("\n"):
                if "\t" not in line:
                    continue
                fields = line.split("\t")
                if "." in fields[0] or "-" in fields[0]:
                    continue
                form = fields[1]
                # Offsets are consumed by get_rsd over whitespace-stripped text, so
                # whitespace inside a token form must not advance the offset
                n_chars = len(re.sub(r'\s','',form))
                for pos in range(offset, offset + n_chars):
                    chars2toks[docname][pos] = tok_num
                offset += n_chars
                toks[docname].append(form)
                tok_num += 1
    return output, chars2toks, toks


def format_range(tok_ids):
    """
    Format 0-based token IDs as a compact, 1-based range string.

    Contiguous runs are written as ``first-last`` (e.g. ``5-24``), single IDs
    as one number, and runs are separated by ``,`` (e.g. ``2,5-24``). Input
    order does not matter and repeated IDs are counted once.

    :param tok_ids: iterable of 0-based integer token IDs
    :return: formatted string; empty string for empty input
    """
    ids = sorted(set(tok_ids))
    if not ids:
        return ""

    # Each run is stored as [first, last]; extend the last run while IDs are adjacent
    runs = [[ids[0], ids[0]]]
    for tid in ids[1:]:
        if tid == runs[-1][1] + 1:
            runs[-1][1] = tid
        else:
            runs.append([tid, tid])

    formatted = []
    for first, last in runs:
        if first == last:
            formatted.append(str(first + 1))
        else:
            formatted.append(str(first + 1) + "-" + str(last + 1))
    return ",".join(formatted)


def format_text(arg1_toks, toks):
    """
    Render the text of a possibly discontinuous token span.

    Tokens are emitted in ascending ID order, separated by spaces; wherever two
    consecutive IDs are not adjacent, the module-level ``ellipsis_marker`` is
    inserted between them. Input order does not matter and repeated IDs are
    used once.

    :param arg1_toks: iterable of integer token IDs
    :param toks: mapping of token ID -> token string
    :return: space-joined text; empty string for empty input
    """
    ordered = sorted(set(arg1_toks))
    if not ordered:
        return ""
    output = []
    # Seed from the smallest ID so that no marker is emitted before the first token
    last = ordered[0] - 1
    for tid in ordered:
        if tid != last + 1:
            output.append(ellipsis_marker)
        output.append(toks[tid])
        last = tid
    return " ".join(output)


def format_sent(arg1_sid, sents):
    """
    Return the space-joined word forms of one sentence.

    :param arg1_sid: index of the sentence in ``sents``
    :param sents: sequence of CoNLL-U sentence strings
    :return: the second column of every token row, joined by spaces; rows
        whose ID contains ``.`` or ``-`` (supertokens, ellipsis tokens) and
        rows without a tab are left out
    """
    rows = (row.split("\t") for row in sents[arg1_sid].split("\n") if "\t" in row)
    forms = [cols[1] for cols in rows if "." not in cols[0] and "-" not in cols[0]]
    return " ".join(forms)


def make_rels(rsd_data, conll_data, corpus="eng.rst.gum", dev_docs=None, test_docs=None):
    """
    Build the dev/train/test relation tables from rsd dependencies and CoNLL-U data.

    For every EDU with a non-zero parent one row is produced (same-unit
    relations themselves are skipped, but their members are merged into the
    unit they belong to). Unit 1 is always the EDU that comes first in the
    text; the direction is ``1>2`` when the child precedes its parent and
    ``1<2`` otherwise. The columns follow the module-level ``header`` and
    depend on the module-level ``outmode``.

    :param rsd_data: docname -> rsd text; tab-separated rows with the EDU number
        in column 1, EDU text (space-tokenized) in column 2, parent EDU number
        in column 7 and relation name in column 8
    :param conll_data: docname -> CoNLL-U text with sentences separated by blank lines
    :param corpus: corpus name; selects the label mapping in ``rel_mapping``
    :param dev_docs: document names belonging to dev; defaults to the
        module-level ``dev_set``
    :param test_docs: document names belonging to test; defaults to the
        module-level ``test_set``
    :return: tuple ``(dev, train, test)`` of newline-terminated table strings,
        each starting with the header row
    :raises IOError: if a same-unit member precedes its parent, if a relation
        has no entry in an existing mapping for ``corpus``, or if a ROOT
        relation is found on a non-root EDU of a mapped corpus
    """
    # Fall back to the partition lists chosen by the script entry point
    if dev_docs is None:
        dev_docs = dev_set
    if test_docs is None:
        test_docs = test_set

    dev = ["\t".join(header)]
    test = ["\t".join(header)]
    train = ["\t".join(header)]

    for docname in rsd_data:
        sent_map = {}  # document-level token index -> sentence index
        toks = {}  # document-level token index -> word form
        sents = conll_data[docname].split("\n\n")

        snum = 0
        toknum = 0
        s_starts = {}  # sentence index -> index of its first token
        s_ends = {}  # sentence index -> index of its last token

        for sent in sents:
            lines = sent.split("\n")
            for line in lines:
                if "\t" in line:
                    fields = line.split("\t")
                    if "-" in fields[0] or "." in fields[0]:
                        continue
                    if fields[0] == "1":
                        s_starts[snum] = toknum
                    sent_map[toknum] = snum
                    toks[toknum] = fields[1]
                    toknum += 1
            s_ends[snum] = toknum - 1
            snum += 1

        rsd_lines = rsd_data[docname].split("\n")

        parents = {}  # EDU id -> parent EDU id (root EDUs excluded)
        texts = {}  # EDU id -> EDU text
        tok_map = {}  # EDU id -> (first token index, last token index)
        offset = 0
        rels = {}  # EDU id -> relation name without the _m/_r suffix
        for line in rsd_lines:
            if "\t" in line:
                fields = line.split("\t")
                edu_id = fields[0]
                edu_parent = fields[6]
                # Strip the trailing _m/_r marker only; removing these substrings
                # anywhere would corrupt names that contain them internally
                relname = re.sub(r'_[mr]$', '', fields[7])
                text = fields[1].strip()
                texts[edu_id] = text
                tok_map[edu_id] = (offset, offset + len(text.split())-1)
                offset += len(text.split())
                if edu_parent == "0":  # Ignore root
                    continue
                parents[edu_id] = edu_parent
                rels[edu_id] = relname

        same_unit_components = defaultdict(set)
        same_unit_data = {}
        # set up same-unit storage
        for edu_id in parents:
            if rels[edu_id].lower().startswith("same"):
                # collect all intervening text inside same-unit children
                parent = parents[edu_id]
                start = int(parent)
                end = int(edu_id)
                unit_ids = [str(x) for x in range(start,end+1)]
                same_unit_components[parent].add(edu_id)
                if parent not in same_unit_data:
                    same_unit_data[parent] = (start,end," ".join([texts[t].strip() for t in unit_ids]))
                else:
                    start, end, text = same_unit_data[parent]
                    if int(edu_id) > start:  # This is a subsequent same-unit member on the right
                        unit_ids = [str(x) for x in range(end+1,int(edu_id)+1)]
                        more_text = " ".join([texts[t].strip() for t in unit_ids])
                        same_unit_data[parent] = (start,int(edu_id)," ".join([text,more_text]))
                    else:
                        raise IOError("LTR same unit!\n")

        output = []
        for edu_id in parents:
            if rels[edu_id].lower().startswith("same"):
                continue  # Skip the actual same-unit relation
            child_text = texts[edu_id]
            parent_id = parents[edu_id]
            if int(edu_id) < int(parent_id):
                direction = "1>2"
                arg1_start, arg1_end = tok_map[edu_id]
                arg2_start, arg2_end = tok_map[parent_id]
            else:
                direction = "1<2"
                arg1_start, arg1_end = tok_map[parent_id]
                arg2_start, arg2_end = tok_map[edu_id]

            parent_text = texts[parent_id]
            # Widen the parent to its whole same-unit span unless the child lies inside it
            if parent_id in same_unit_data:
                start, end, text = same_unit_data[parent_id]
                if int(edu_id) < start or int(edu_id)> end:
                    parent_text = text
                    if int(edu_id) < int(parent_id):
                        arg2_start, _ = tok_map[str(start)]
                        _, arg2_end = tok_map[str(end)]
                    else:
                        arg1_start, _ = tok_map[str(start)]
                        _, arg1_end = tok_map[str(end)]

            # Likewise widen the child unless the parent lies inside its same-unit span
            if edu_id in same_unit_data:
                start, end, text = same_unit_data[edu_id]
                if int(parent_id) < start or int(parent_id)> end:
                    child_text = text
                    if int(edu_id) < int(parent_id):
                        arg1_start, _ = tok_map[str(start)]
                        _, arg1_end = tok_map[str(end)]
                    else:
                        arg2_start, _ = tok_map[str(start)]
                        _, arg2_end = tok_map[str(end)]

            arg1_sid = sent_map[arg1_start]
            arg2_sid = sent_map[arg2_start]

            s1_start = s_starts[arg1_sid]
            s1_end = s_ends[arg1_sid]
            s2_start = s_starts[arg2_sid]
            s2_end = s_ends[arg2_sid]

            if outmode == "standoff":
                comp1 = edu_id if int(edu_id) < int(parent_id) else parent_id
                comp2 = parent_id if int(edu_id) < int(parent_id) else edu_id
                # Reduce EDUs to minimal span in standoff mode
                arg1_toks = list(range(tok_map[comp1][0], tok_map[comp1][1]+1))
                arg2_toks = list(range(tok_map[comp2][0], tok_map[comp2][1]+1))
                # Add explicit discontinuous spans
                if comp1 in same_unit_components:
                    for component in same_unit_components[comp1]:
                        component_toks = list(range(tok_map[component][0], tok_map[component][1]+1))
                        arg1_toks += component_toks
                if comp2 in same_unit_components:
                    for component in same_unit_components[comp2]:
                        component_toks = list(range(tok_map[component][0], tok_map[component][1]+1))
                        arg2_toks += component_toks
                arg1_txt = format_text(arg1_toks,toks)
                arg1_sent = format_sent(arg1_sid,sents)
                arg2_txt = format_text(arg2_toks,toks)
                arg2_sent = format_sent(arg2_sid,sents)
                arg1_toks = format_range(arg1_toks)
                arg2_toks = format_range(arg2_toks)
                s1_toks = format_range(list(range(s1_start,s1_end+1)))
                s2_toks = format_range(list(range(s2_start,s2_end+1)))

                mapped_rel = rels[edu_id]
                if corpus in rel_mapping:
                    corpus_map = rel_mapping[corpus]
                    if mapped_rel in corpus_map:
                        mapped_rel = corpus_map[mapped_rel]
                    elif mapped_rel.lower() in corpus_map:
                        mapped_rel = corpus_map[mapped_rel.lower()]
                    elif mapped_rel != "ROOT":
                        raise IOError("no rel map "+mapped_rel)
                    else:
                        raise IOError("found ROOT entry in " +corpus + ": "+docname)
                if corpus.startswith("fas."):
                    mapped_rel = mapped_rel.lower()
                output.append("\t".join([docname,arg1_toks,arg2_toks,arg1_txt,arg2_txt,s1_toks,s2_toks,arg1_sent,arg2_sent,direction,rels[edu_id],mapped_rel]))
            else:
                # Split the tokens of both sentences into the context before unit 1,
                # unit 1, the material between the units, unit 2 and the rest
                pre = []
                pre_toks = []
                arg1 = []
                arg1_toks = []
                mid = []
                mid_toks = []
                arg2 = []
                arg2_toks = []
                post = []
                post_toks = []
                for tid in sorted(set(range(s1_start,s1_end+1)) | set(range(s2_start, s2_end+1))):
                    tok = toks[tid]
                    if tid < arg1_start:
                        pre.append(tok)
                        pre_toks.append(tid)
                    elif tid <= arg1_end:
                        arg1.append(tok)
                        arg1_toks.append(tid)
                    elif tid < arg2_start:
                        mid.append(tok)
                        mid_toks.append(tid)
                    elif tid <= arg2_end:
                        arg2.append(tok)
                        arg2_toks.append(tid)
                    else:
                        post.append(tok)
                        post_toks.append(tid)

                pre = " ".join(pre) if len(pre) > 0 else "NULL"
                pre_toks = str(min(pre_toks)) if len(pre_toks) > 0 else "NA"
                arg1 = " ".join(arg1)
                arg1_toks = str(min(arg1_toks))
                mid = " ".join(mid) if len(mid) > 0 else "NULL"
                mid_toks = str(min(mid_toks)) if len(mid_toks) > 0 else "NA"
                arg2 = " ".join(arg2)
                arg2_toks = str(min(arg2_toks))
                post = " ".join(post) if len(post) > 0 else "NULL"
                post_toks = str(min(post_toks)) if len(post_toks) > 0 else "NA"

                indices = ";".join([pre_toks, arg1_toks, mid_toks, arg2_toks, post_toks])
                output.append("\t".join([docname,indices,pre,arg1,mid,arg2,post,direction,rels[edu_id]]))

        if docname in dev_docs:
            dev += output
        elif docname in test_docs:
            test += output
        else:
            train += output

    dev = "\n".join(dev) + "\n"
    train = "\n".join(train) + "\n"
    test = "\n".join(test) + "\n"

    return dev, train, test

if __name__ == "__main__":
    # Script entry point: pick the dev/test document lists for the requested
    # corpus, read its .conllu and .rsd files, and write the _dev/_test/_train
    # .rels tables (plus .tok files with --plain) into the corpus data directory.

    p = ArgumentParser()
    p.add_argument("-c","--corpus",default="eng.rst.gum")
    p.add_argument("-p","--plain",action="store_true",help="also write plain .tok files")
    opts = p.parse_args()

    # dev_set / test_set must stay module-level names: make_rels reads them.
    # The first matching rule wins, so the order of the checks matters.
    if opts.corpus == "PCC" or opts.corpus =="deu.rst.pcc":
        dev_set = pcc_dev_set
        test_set = pcc_test_set
    elif "spanish" in opts.corpus.lower() or opts.corpus == "spa.rst.rststb":
        dev_set = spanish_dev
        test_set = spanish_test
    elif "annodis" in opts.corpus.lower():
        dev_set = annodis_dev
        test_set = annodis_test
    elif "stac" in opts.corpus.lower():
        # Substring match, like the other corpora, so that dotted corpus names
        # containing "stac" do not silently fall through to the GUM lists
        dev_set = stac_dev
        test_set = stac_test
    elif "rstdt" == opts.corpus.lower() or "rstdt" in opts.corpus:
        dev_set = rstdt_dev
        test_set = rstdt_test
    elif "sctb_spa" == opts.corpus.lower() or opts.corpus == "spa.rst.sctb":
        dev_set = sctb_spa_dev
        test_set = sctb_spa_test
    elif "sctb_zho" == opts.corpus.lower() or opts.corpus == "zho.rst.sctb":
        dev_set = sctb_zho_dev
        test_set = sctb_zho_test
    elif "rus." in opts.corpus.lower():
        dev_set = rusrt_dev
        test_set = rusrt_test
    elif "nldt" == opts.corpus.lower() or "nldt" in opts.corpus:
        dev_set = nldt_dev
        test_set = nldt_test
    elif opts.corpus.lower().startswith("eus"):
        dev_set = basque_dev
        test_set = basque_test
    elif opts.corpus.lower().startswith("por"):
        dev_set = por_cstn_dev
        test_set = por_cstn_test
    elif opts.corpus.startswith("fas"):
        dev_set = persian_dev
        test_set = persian_test
    else:
        dev_set = gum_dev
        test_set = gum_test

    corpus = opts.corpus
    script_dir = os.path.dirname(os.path.realpath(__file__))
    # Both directories keep a trailing separator: the readers append a glob pattern directly
    corpus_dir = script_dir + os.sep + "data" + os.sep + corpus + os.sep
    rsd_store = script_dir + os.sep + "rsd_store" + os.sep

    conll_data, chars2toks, toks_by_doc = get_conll(corpus_dir)
    # For this corpus the EDU text is rebuilt from the CoNLL-U tokens
    add_missing = corpus == "eng.rst.rstdt"
    rsd_data = get_rsd(rsd_store + corpus + os.sep, chars2toks, toks_by_doc, conll_data, add_missing_tokens=add_missing)

    dev, train, test = make_rels(rsd_data, conll_data, corpus=corpus)

    if opts.plain:
        # Collect the plain version of each document under its partition
        plain = {"dev": [], "test": [], "train": []}
        for docname in conll_data:
            if docname in dev_set:
                partition = "dev"
            elif docname in test_set:
                partition = "test"
            else:
                partition = "train"
            plain[partition].append(make_plain(conll_data[docname].strip() + "\n\n"))
        for partition in ("dev", "test", "train"):
            with io.open(corpus_dir + corpus + "_" + partition + ".tok", 'w', encoding="utf8", newline="\n") as f:
                f.write("".join(plain[partition]))

    for partition, table in (("dev", dev), ("test", test), ("train", train)):
        with io.open(corpus_dir + corpus + "_" + partition + ".rels", 'w', encoding="utf8", newline="\n") as f:
            f.write(table)