##| Copyright: (C) 2019-2020 Kevin Larke <contact AT larke DOT org> 
##| License: GNU GPL version 3.0 or above. See the accompanying LICENSE file.
import rtmidi
import rtmidi.midiutil

from result import Result

class MidiDevice(object):
    """Manage one MIDI input port and one MIDI output port through python-rtmidi.

    Incoming and outgoing messages can optionally be reported back as text
    'monitor' messages, and incoming messages can optionally be forwarded
    ("through") to the output port. Status and data are returned as plain
    dicts carrying 'type', 'op' and 'dir' keys (and usually 'value').
    """

    def __init__(self, **kwargs ):
        self.mip          = None   # rtmidi input object; None until an input port is opened
        self.mop          = None   # rtmidi output object; None until an output port is opened
        self.inMonitorFl  = False  # report incoming messages as text 'monitor' messages
        self.outMonitorFl = False  # report outgoing messages as text 'monitor' messages
        self.throughFl    = False  # forward incoming messages to the output port
        self.inPortLabel  = None   # cleaned label of the open input port
        self.outPortLabel = None   # cleaned label of the open output port
        self.setup(**kwargs)

    def setup( self, **kwargs ):
        """Apply any recognised configuration keywords.

        Recognised keys: inPortLabel, outPortLabel, inMonitorFl, outMonitorFl,
        throughFl. Unrecognised keys are ignored.

        Returns:
            Result: the accumulated results of any port selections performed.
        """
        res = Result()

        if not kwargs:
            return res

        if 'inPortLabel' in kwargs:
            res += self.select_port( True, kwargs['inPortLabel'] )

        if 'outPortLabel' in kwargs:
            res += self.select_port( False, kwargs['outPortLabel'] )

        if 'inMonitorFl' in kwargs:
            self.enable_monitor( True, kwargs['inMonitorFl'] )

        if 'outMonitorFl' in kwargs:
            # Output monitor: inDirFl must be False (previously enabled the input monitor).
            self.enable_monitor( False, kwargs['outMonitorFl'] )

        if 'throughFl' in kwargs:
            self.enable_through( kwargs['throughFl'] )

        return res

    def _clean_port_label( self, portLabel ):
        """Return portLabel with its last space-separated token (the numeric id) removed."""
        return ' '.join(portLabel.split(' ')[:-1])

    def _get_port_list( self, inDirFl ):
        """Return the cleaned labels of the currently available input (or output) ports."""
        # A temporary rtmidi object is created just to enumerate the ports.
        dev   =  rtmidi.MidiIn() if inDirFl else rtmidi.MidiOut()

        # get port list and drop the numeric id at the end of the port label
        return [ self._clean_port_label(p)  for p in dev.get_ports() ]

    def get_port_list( self, inDirFl ):
        """Return a 'list' message describing the available input or output ports."""
        return { 'type':'midi',
                 'dir': 'in' if inDirFl else 'out',
                 'op':  'list',
                 'listL': self._get_port_list( inDirFl )
        }

    def get_in_port_list( self ):
        """Return a 'list' message describing the available input ports."""
        return self.get_port_list( True )

    def get_out_port_list( self ):
        """Return a 'list' message describing the available output ports."""
        return self.get_port_list( False )

    @staticmethod
    def _close_port( port ):
        """Close a previously opened rtmidi port object, if there is one."""
        if port is not None:
            port.close_port()

    def select_port( self, inDirFl, portLabel ):
        """Open the input (inDirFl True) or output port whose cleaned label is portLabel.

        A falsy portLabel is a no-op. Any port already open in that direction
        is closed before the new one is opened, so repeated selection does not
        leak open ports.

        Returns:
            Result: carries an error if portLabel is not an available port.
        """
        res = Result()

        if not portLabel:
            return res

        dirLabel = "input" if inDirFl else "output"

        portL = self._get_port_list( inDirFl )

        if portLabel not in portL:
            res.set_error("The port '%s' is not an available %s port." % (portLabel,dirLabel))
            return res

        # Membership was checked above, so index() cannot fail here.
        port_idx = portL.index(portLabel)

        # Close the old port first (and clear the state) so that a failed open
        # leaves the device in a consistent "no port" state rather than holding
        # a reference to a closed port.
        if inDirFl:
            self._close_port( self.mip )
            self.mip, self.inPortLabel = None, None
            self.mip, label = rtmidi.midiutil.open_midiinput(port=port_idx)
            self.inPortLabel = self._clean_port_label(label)
        else:
            self._close_port( self.mop )
            self.mop, self.outPortLabel = None, None
            self.mop, label = rtmidi.midiutil.open_midioutput(port=port_idx)
            self.outPortLabel = self._clean_port_label(label)

        return res

    def select_in_port( self, portLabel ):
        """Open the input port with the given cleaned label (see select_port)."""
        return self.select_port( True, portLabel )

    def select_out_port( self, portLabel ):
        """Open the output port with the given cleaned label (see select_port)."""
        return self.select_port( False, portLabel )

    def enable_through( self, throughFl ):
        """Enable or disable forwarding of incoming messages to the output port."""
        self.throughFl = throughFl

    def enable_monitor( self, inDirFl, monitorFl ):
        """Enable or disable text monitoring of incoming (inDirFl True) or outgoing messages."""
        if inDirFl:
            self.inMonitorFl = monitorFl
        else:
            self.outMonitorFl = monitorFl

    def enable_in_monitor( self, monitorFl):
        """Enable or disable text monitoring of incoming messages."""
        self.enable_monitor( True, monitorFl )

    def enable_out_monitor( self, monitorFl):
        """Enable or disable text monitoring of outgoing messages."""
        self.enable_monitor( False, monitorFl )

    def port_name( self, inDirFl ):
        """Return the cleaned label of the open input (inDirFl True) or output port, or None."""
        return self.inPortLabel if inDirFl else self.outPortLabel

    def in_port_name( self ):
        """Return the cleaned label of the open input port, or None."""
        return self.port_name(True)

    def out_port_name( self ):
        """Return the cleaned label of the open output port, or None."""
        return self.port_name(False)

    def _midi_data_to_text_msg( self, inFl, midi_data ):
        """Format up to three MIDI bytes as a text 'monitor' message.

        The status byte is rendered as two hex digits; the next two data bytes
        (if present) as right-aligned decimals. The text is prefixed with
        "in:  " or "out: " according to inFl.
        """
        text = ""
        if len(midi_data) > 0:
            text += "{:02x}".format(midi_data[0])

        if len(midi_data) > 1:
            text += " {:3d}".format(midi_data[1])

        if len(midi_data) > 2:
            text += " {:3d}".format(midi_data[2])

        text = ("in:  " if inFl else "out: ") + text
        # 'dir' uses the same 'in'/'out' strings as every other message this class emits.
        return { 'type':'midi', 'dir':'in' if inFl else 'out', 'op':'monitor', 'value':text }

    def get_input( self ):
        """Poll the input port once and return the resulting messages.

        Returns a list that is empty when no input port is open or no message
        was pending. Otherwise it holds an optional text 'monitor' message
        (when input monitoring is enabled) followed by a 'data' message whose
        value is the raw MIDI bytes. When 'through' is enabled and an output
        port is open, the bytes are also sent to the output port.
        """
        o_msgL = []

        if self.mip is not None:
            # get_message() yields (message, delta_time) or a falsy value.
            midi_msg = self.mip.get_message()
            if midi_msg and midi_msg[0]:

                if self.inMonitorFl:
                    o_msgL.append( self._midi_data_to_text_msg(True,midi_msg[0]) )

                if self.throughFl and self.mop is not None:
                    self.mop.send_message(midi_msg[0])

                o_msgL.append( { 'type':'midi', 'op':'data', 'dir':'in', 'value':midi_msg[0] } )

        return o_msgL

    def send_output( self, m ):
        """Send the raw MIDI bytes m to the output port (if one is open).

        Returns a list holding a text 'monitor' message when output monitoring
        is enabled, otherwise an empty list. Monitoring happens even when no
        output port is open.
        """
        o_msgL = []

        if self.mop is not None:
            self.mop.send_message(m)

        if self.outMonitorFl:
            o_msgL += [self._midi_data_to_text_msg( False, m )]

        return o_msgL

    def send_note_on( self, pitch, vel, ch=0 ):
        """Send a note-on (status 0x90 + ch) message."""
        return self.send_output( [ 0x90+ch, pitch, vel ] )

    def send_note_off( self, pitch, ch=0 ):
        """Send a note-off for pitch, expressed as a note-on with velocity 0."""
        return self.send_note_on( pitch, 0, ch )

    def send_pgm_change( self, pgm, ch=0 ):
        """Send a program-change (status 0xc0 + ch) message."""
        return self.send_output( [ 0xc0+ch, pgm ] )

    def send_pbend( self, val, ch=0 ):
        """Send a pitch-bend (status 0xe0 + ch) message, value split into 7-bit LSB/MSB."""
        # NOTE(review): the standard 14-bit pitch-bend range is 0..16383 with
        # 8192 as centre; this assert rejects the upper half. Confirm intent.
        assert( val < 8192 )

        ival = int(val)

        lsb = ival & 0x7f
        msb = (ival >> 7) & 0x7f

        return self.send_output( [ 0xe0+ch, lsb, msb ] )

    def send_controller( self, num, value, ch=0 ):
        """Send a control-change (status 0xb0 + ch) message."""
        return self.send_output( [0xb0+ch, num, value ] )

    def send_all_notes_off(self, ch=0 ):
        """Send the 'All Notes Off' channel-mode message (controller 123, value 0)."""
        return self.send_controller( 123, 0, ch=ch )

    def get_state( self ):
        """Return a dict snapshot of the flags, open port labels and available ports."""
        return  {
            "inMonitorFl":self.inMonitorFl,
            "outMonitorFl":self.outMonitorFl,
            "throughFl":self.throughFl,
            "inPortLabel":self.inPortLabel,
            "outPortLabel":self.outPortLabel,
            "inPortL":self.get_in_port_list(),
            "outPortL":self.get_out_port_list()
        }

    def on_command( self, m, ms ):
        """Dispatch a 'midi' command object m ('sel', 'through' or 'monitor' ops).

        m is expected to expose .type, .op, .dir ('in'/'out') and .value
        attributes; ms is accepted for interface compatibility but unused here.

        Returns:
            list: the Result objects produced by any port selection.
        """
        errL = []

        if m.type == 'midi':
            if m.op == 'sel':
                errL.append(self.select_port( m.dir=='in', m.value ))

            elif m.op == 'through':
                self.enable_through( m.value )

            elif m.op == 'monitor':
                self.enable_monitor( m.dir=='in', m.value )

        return errL

if __name__ == "__main__":

    # Smoke test: print the available MIDI input ports, then the output ports.
    device = MidiDevice()

    for isInput in (True, False):
        print(device.get_port_list( isInput ))