#
# Script illustrating the combination of 12m, 7m and total-power (TP) line data sets.
#
# Using CASA 5.1.1-5 (mainly to use the newest version of tclean)
#
# step 1:  Getting the three data sets
# step 2:  Getting listobs, plotants of data sets
# step 3:  Check wts of interferometric data
#      3.1 7m data
#      3.2 12m data
# step 4:  Get 7m, 12m mosaic coverage.  Define central position
#           
# step 5:  Check spectrum and frequency alignment of data
#      5.1:  TP spectrum from image
#      5.2:  7m avg visibility
#      5.3:  12m avg visibility
#      5.4:  7m dirty image peaks
#      5.5:  12m dirty image peaks
#          No offsets between 7m-12m data channels
#      5.6   Offset of 28 channels with total power channels
#
# step 6:  Make spectral line cube for 7m and 12m
# step 7:  Make continuum from other channels
#      7.1 For 7m
#      7.2 for 12m
# step 8:  concatenate data sets.  Equal wt and x10 7m weight
# step 9:  Try automatic box cleaning.
# step 10: Obtain the 51 TP line images that correspond to the interferometric channels
# step 11: regrid of total power image on interferometric images
# step 12: feathering examples: chan 10, 30, 45
#
########################################################################
import os

import numpy as np
import pylab as pl
cspeed = 299792.458  # speed of light in km/s; used for channel velocities in step 5.1
pcenter = 'J2000 16h09m18.1 -39d04m44.0'   #  Phase centre for imaging; get this from step 4
#
# NOTE(review): the steps below hard-code these file names as literals
# instead of using D7 / D12 / TPI; keep them in sync if names change.
D7 = 'Lup_7m.ms'     # Cut down calibration 7m data
D12 = 'Lup_12m.ms'   # Cut down calibration 12m data
TPI = 'Lup_TP.image' # Cut down TP image.
#
# Steps to execute: each section below runs only if its step number is in
# this list (tested with `in`, so float step ids like 12.1 must match exactly).
dostep = [12.1]
#
#  Steps 1 and 2 get the data/images, then list and plot the measurement sets.
#
if 1 in dostep:
    print 'copying 7m data'
    os.system('cp -r ../adele_data/for_delivery/Lup_7m.ms .')
    print 'copying 12m data'
    os.system('cp -r ../adele_data/for_delivery/Lup_12m.ms .')
    print 'copying TP image'
    os.system('cp -r ../adele_data/for_delivery/Lup_TP.image .')
#
if 2 in dostep:
    print 'listobs and plotants for ms'
    os.system('rm -rf Lup_7m.listobs')
    listobs('Lup_7m.ms', listfile='Lup_7m.listobs')
    os.system('rm -rf Lup_12m.listobs')
    listobs('Lup_12m.ms', listfile='Lup_12m.listobs')
    os.system('rm -rf Lup_7m_plotants.png')
    plotants('Lup_7m.ms', figfile='Lup_7m_plotants.png')
    os.system('rm -rf Lup_12m_plotants.png')
    plotants('Lup_12m.ms', figfile='Lup_12m_plotants.png')

if 3.1 in dostep:
    print 'plot weights of 7m'
    plotms(vis='Lup_7m.ms',
           xaxis = 'uvdist',
           yaxis = 'wt',
           antenna = '*&*',
           spw = '',
           iteraxis = 'spw')
if 3.2 in dostep:
    print 'plot weights of 12m'
    plotms(vis='Lup_12m.ms',
           xaxis = 'uvdist',
           yaxis = 'wt',
           antenna = '*&*',
           spw = '',
           iteraxis = 'spw')
if 3.3 in dostep:
    print 'flag 7m spw 7,8'
    flagmanager(vis='Lup_7m.ms',
                mode = 'save',
                comment = 'before_spw_flag',
                versionname = 'spw_flag')
    flagdata(vis='Lup_7m.ms',
             mode = 'manual',
             spw = '7,8',
             flagbackup = False)
#
if 4 in dostep:
    print 'Make mosaic coverage'
    os.system('rm -rf mosaic_7.png')
    print 'make mosaic map plot'
    au.plotmosaic(vis = 'Lup_7m.ms',figfile = 'mosaic_7.png',
                  plotrange=[160,-160,-90,90])
    os.system('rm -rf mosaic_12.png')
    print 'make mosaic map plot'
    au.plotmosaic(vis = 'Lup_12m.ms',figfile = 'mosaic_12.png',
                  plotrange=[160,-160,-90,90])
    #
    os.system('eog mosaic*png')
#
if 5.1 in dostep:
    print 'get image peak spectrum of TP'
    peak = []
    flux = []
    kchan = []
    kvel = []
    kfreq = []
    pl.clf()
    lch = 148
    hch = 198
    for ich in range(lch,hch):
        a = imstat('Lup_TP.image', chans = str(ich))
        peak.append(a['max'][0])
        flux.append(a['flux'][0])
        temp = a['blcf']
        chan_freq = float(temp.split(',')[2][:-2])
        ref_freq = imhead('Lup_TP.image')['refval'][2]
        #kfreq.append(float(temp1))
        kchan.append(ich)
        temp2 = (ref_freq - chan_freq)/ref_freq * cspeed
        kvel.append(temp2)
        #print '%5d  %8.5f  %7.2f'%(ich,temp1/1.0E9, temp2)
    #os.system('rm TP_spectrum.png')
    clearplot()
    pl.clf()
    pl.plot(kchan,peak,'b-')
    #pl.plot(kchan,flux,'g-')
    pl.xlim(lch,hch)
    pl.title('TP Peak/integration FLux density',fontsize=18)
    pl.xlabel('Channel number ', fontsize=16)
    pl.ylabel('Flux Density (Jy)', fontsize=16)
    pl.savefig('peak_TP.png')
#
if 5.2 in dostep:
    print 'visibility spectrum of 7m data'
    os.system('rm -rf chan_amp_7m.png')
    plotms(vis = 'Lup_7m.ms',
           field = '',
           xaxis = 'channel',
           yaxis = 'amp',
           spw = '',
           antenna = '*&*',
           avgbaseline = True,
           avgtime = '3000',
           #iteraxis = 'spw',
           plotfile = 'chan_amp_7m.png',
           customsymbol = True, symbolsize = 8,
           avgscan = True)
if 5.3 in dostep:
    print 'visibility spectrum of 12m data'
    os.system('rm -rf chan_amp_12m.png')
    plotms(vis = 'Lup_12m.ms',
           #gridrows = 2,
           #gridcols = 1,
           #rowindex = 1,
           #colindex = 0,
           #plotindex = 1,
           field = '',
           xaxis = 'channel',
           yaxis = 'amp',
           uvrange = '0~40',
           antenna = '*&*',
           spw = '',
           avgbaseline = True,
           avgtime = '3000',
           avgspw = True,
           customsymbol = True, symbolsize = 8,
           plotfile = 'chan_amp_12m.png',
           avgscan = True)
#
if 5.4 in dostep:
    print 'make dirty images of 7-m to better check line emission'
    xpcenter = pcenter
    peak7 = []
    nich = []
    clearplot()
    for ich in range(120,170):
        str_ich = str(ich)
        Im = '7m_dirty_'+str_ich
        os.system('rm -rf '+Im+'*')
        print 'image channel ',str_ich
        tclean(vis = 'Lup_7m.ms',
               field = '',
               imagename = Im,
               spw = '*:'+str_ich,
               datacolumn = 'data',
               imsize = 256,
               deconvolver = 'hogbom',
               phasecenter = xpcenter,
               cell = '1.5arcsec',
               weighting = 'natural',
               interactive = False,
               niter = 0,
               restart = True,
               usemask = 'user')
        #
if 5.41 in dostep:
    peak7 = []
    nich = []
    clearplot()
    # Peak of each 7m dirty residual written by step 5.4.
    for ich in range(120, 170):
        str_ich = str(ich)
        Im = '7m_dirty_' + str_ich
        residual_max = imstat(Im + '.residual')['max'][0]
        peak7.append(residual_max)
        nich.append(ich)
    pl.clf()
    pl.plot(nich, peak7, 'b-')
    pl.title('peak flux density in 7m image', fontsize=18)
    pl.xlim(120, 170)
    pl.xlabel('channel number', fontsize=16)
    pl.ylabel('peak flux density Jy', fontsize=16)
    pl.savefig('peak_7m.png')
#
if 5.5 in dostep:
    print 'make dirty images of 12-m to better check line emission'
    xpcenter = pcenter
    peak12 = []
    nich = []
    for ich in range(120,170):
        str_ich = str(ich)
        Im = '12m_dirty_'+str_ich
        os.system('rm -rf '+Im+'*')
        print 'image channel ',str_ich
        tclean(vis = 'Lup_12m.ms',
               field = '',
               imagename = Im,
               spw = '*:'+str_ich,
               datacolumn = 'data',
               imsize = 256,
               deconvolver = 'hogbom',
               phasecenter = xpcenter,
               uvrange = '0~40',
               cell = '1.5arcsec',
               weighting = 'natural',
               interactive = False,
               niter = 0,
               restart = True,
               usemask = 'user')
        #
if 5.51 in dostep:
    # Peak of each 12m dirty residual written by step 5.5.
    peak12 = []
    nich = []
    for ich in range(120, 170):
        Im = '12m_dirty_' + str(ich)
        peak12.append(imstat(Im + '.residual')['max'][0])
        nich.append(ich)
    # Draw and save the figure once, after all channels are collected
    # (as step 5.41 does for the 7m data), instead of every iteration.
    pl.clf()
    pl.plot(nich, peak12, 'b-')
    pl.title('peak flux density in 12m image', fontsize=18)
    pl.xlim(120, 170)
    pl.xlabel('channel number', fontsize=16)
    pl.ylabel('peak flux density Jy', fontsize=16)
    pl.savefig('peak_12m.png')
#
if 5.6 in dostep:
    print 'compare spectrum'
    os.system('eog peak*png &')
#
if 6 in dostep:
    print 'make line cube for 7m and 12m'
    print 'first 7m data'
    os.system('rm -rf D7_line.ms')
    split(vis='Lup_7m.ms',
          spw = '*:120~170',
          outputvis = 'D7_line.ms',
          datacolumn = 'data')
    os.system('rm -rf D7_line.listobs')
    listobs(vis='D7_line.ms', listfile='D7_line.listobs')
    print 'next 12m data'
    os.system('rm -rf D12_line.ms')
    split(vis='Lup_12m.ms',
          spw = '*:120~170',
          outputvis = 'D12_line.ms',
          datacolumn = 'data')
    os.system('rm -rf D12_line.listobs')
    listobs(vis='D12_line.ms', listfile='D12_line.listobs')
#
if 7.1 in dostep:
    print 'make 7m continuum uvdata'
    flagmanager(vis='Lup_7m.ms',
                mode = 'save', versionname = 'before_continuum')
    flagdata(vis='Lup_7m.ms', spw = '*:120~170',flagbackup = False)
    split(vis='Lup_7m.ms',
          spw = '',
          outputvis = 'D7_cont.ms',
          datacolumn = 'data')
    listobs(vis='D7_cont.ms', listfile='D7_cont.listobs')
    flagmanager(vis='Lup_7m.ms',
                mode = 'restore', versionname = 'before_continuum')
if 7.2 in dostep:
    print 'make 12m continuum uvdata'
    flagmanager(vis='Lup_12m.ms',
                mode = 'save', versionname = 'before_continuum')
    flagdata(vis='Lup_12m.ms', spw = '*:120~170',flagbackup = False)
    split(vis='Lup_12m.ms',
          spw = '',
          outputvis = 'D12_cont.ms',
          datacolumn = 'data')
    listobs(vis='D12_cont.ms', listfile='D12_cont.listobs')
    flagmanager(vis='Lup_12m.ms',
                mode = 'restore', versionname = 'before_continuum')
#
if 8 in dostep:
    print 'concatenate data sets (both weightings'
    os.system('rm -rf D7+D12_line.ms')
    concat (vis=['D7_line.ms','D12_line.ms'],
            concatvis = 'D7+D12_line.ms',
            visweightscale = [1.0,1.0],
            copypointing = False)
    os.system('rm -rf D7x10+D12_line.ms')
    print 'now weight 7m by factor of 10!'
    concat (vis=['D7_line.ms','D12_line.ms'],
            concatvis = 'D7x10+D12_line.ms',
            visweightscale = [10.0,1.0],
            copypointing = False)
    os.system('rm -rf D7x10+D12_line.listobs')
    listobs(vis='D7x10+D12_line.ms', listfile = 'D7x10+D12_line.listobs')
#
if 9 in dostep:
    print 'make image of 7-m data+12data at one frequency autobox'
    for ich in range(0,51):
        Im = 'W10_UV50_ch_'+str(ich)
        os.system('rm -rf '+Im+'*')
        print 'cleaning chan ', ich
        tclean(vis = 'D7x10+D12_line.ms',
               field = '',
               imagename = Im,
               uvrange = '0~50',
               spw = '*:'+str(ich),
               datacolumn = 'data',
               imsize = 512,
               #mask = 'circle[[280pix,280pix],2pix]',
               deconvolver = 'multiscale',
               scales = [0,10],
               phasecenter = pcenter,
               cell = '0.75arcsec',
               weighting = 'briggs', robust = 0.5,
               interactive = False,
               niter = 500,
               restart = True,
               pbmask = 0.2,
               noisethreshold = 3.0,
               sidelobethreshold = 2.0,
               cutthreshold = 0.5,
               usemask = 'auto-multithresh')
    #
if 9.1 in dostep:
    print 'plot each channel'
    for ich in range(0,51):
        Im = 'W10_UV50_ch_'+str(ich)
        print 'viewing chan ', ich
        Jm = Im+'.image'
        imax = imstat(Jm)['max'][0]
        imax10 = imax/10.0
        rms = imstat(Im+'.residual')['rms'][0]
        imin = np.maximum(imax10,1*rms)
        print 'viewing chan %3d  max=%7.4f  cntr=%7.4f'%(ich, imax, imin)
        imview(raster = {'file': Jm, 'range':[-imin,imax], 'colormap': 'RGB 1'}, 
               contour = {'file':Jm,
                          'levels':[-2,-1,1,2,3,4,5,6,7,8,9],
                          'unit':float(imin)},
               zoom = 0)
#
if 10 in dostep:
    print 'split out each of the 51 TP images'
    os.system('rm -rf TP_overlap.image*')
    for ich in range(148,199):
        nch = ich - 148
        imsubimage(imagename = 'Lup_TP.image',
               outfile = 'TP_overlap.image_ch'+str(nch),
               chans = str(ich))    
#
##############################################################
if 11 in dostep:
   print 'regrid Total Power Images'
   os.system('rm -rf tp_regrid_ch*')
   for ich in range(0,50):
        Im = 'TP_overlap.image_ch'+str(ich)
        Om = 'TP_regrid_ch_'+str(ich)+'.image'
        Jm = 'W10_UV50_ch_'+str(ich)+'.image'
        os.system('rm -rf '+Om)
        imregrid(imagename = Im,
                 template = Jm,
                 output = Om,
                 axes = [0,1])
#
if 12.1 in dostep:
    for ich in [45]: # ch 10, 30, 45     
        print 'feather channel, ich'
        INTIm = 'W10_UV50_ch_'+str(ich)+'.image'
        imax1 = imstat(INTIm)['max'][0]
        icntr1 = imax1/10
        imview(contour = {'file':INTIm,
                          'levels':[-2,-1,1,2,3,4,5,6,7,8,9,12,15],
                          'unit':float(icntr1)},
               zoom = {'blc':[100,130],'trc':[380,380]})        
        TPIm = 'TP_regrid_ch_'+str(ich)+'.image'
        imax2 = imstat(TPIm)['max'][0]
        icntr2 = imax2/10
        imview(contour = {'file':TPIm,
                          'levels':[-2,-1,1,2,3,4,5,6,7,8,9],
                          'unit':float(icntr2)},
               zoom = {'blc':[100,130],'trc':[380,380]})
        outname = 'feather_sf1.0_ch_'+str(ich)
        os.system('rm -rf '+outname)
        feather(imagename=outname,
                highres = INTIm,
                lowres = TPIm,
                sdfactor = 1.0)        
        imview(contour = {'file':outname,
                          'levels':[-2,-1,1,2,3,4,5,6,7,8,9,12,15],
                          'unit':float(icntr1)},
               zoom = {'blc':[100,130],'trc':[380,380]})
