# -*- coding: utf-8 -*- 
###########################################################################
# Eole NG - 2007
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
###########################################################################

"""
Agent zephir de surveillance des connexions (ip conntrack)
"""

from twisted.internet import defer
from twisted.internet.utils import getProcessOutput
from socket import ntohs, ntohl
from IPy import IP

from zephir.monitor.agentmanager.agent import Agent
from zephir.monitor.agentmanager import status
from zephir.monitor.agentmanager.data import HTMLData, TableData
from zephir.monitor.agentmanager.util import percent

from pyctd.conntracking import DataCollector
from pynetfilter_conntrack import constant

# Reverse lookup tables built from the pynetfilter_conntrack ``constant``
# module: each maps a numeric constant value back to its symbolic name with
# the prefix stripped.
#   IPPROTO_NAMES    : IPPROTO_*     values -> IP protocol name (e.g. 'ICMP')
#   IPPROTO_FAMILIES : PF_*          values -> protocol family name
#   TCPSTATES        : NFCT_TCP_ST_* values -> TCP conntrack state name
# When several constants share a value, the name that comes last in dir()
# order wins.
IPPROTO_NAMES = {}
IPPROTO_FAMILIES = {}
TCPSTATES = {}

for attr_name in dir(constant):
    # Every prefix is tested independently (no short-circuit between them).
    for prefix, table in (('IPPROTO_', IPPROTO_NAMES),
                          ('PF_', IPPROTO_FAMILIES),
                          ('NFCT_TCP_ST_', TCPSTATES)):
        if attr_name.startswith(prefix):
            table[getattr(constant, attr_name)] = attr_name[len(prefix):]

class Connexions(Agent):
    """Agent listing the connections tracked by netfilter (ip conntrack).

    Each measure is a list of per-connection dicts built from a pyctd
    ``DataCollector('conntrack', ...)`` snapshot; the latest one is shown
    through ``self.table``.
    """

    def __init__(self, name,
                 **params):
        Agent.__init__(self, name, **params)
        self.table = TableData([
            ('id', 'Id_cnx', {'align':'right'}, None),
            ('user', 'Utilisateur', {'align':'left'}, None),
            ('rate_in', 'Débit Entrant', {'align':'right'}, None),
            ('rate_out', 'Débit Sortant', {'align':'right'}, None),
            ('port', 'Port', {'align':'left'}, None),
            ('src', 'Source', {'align':'right'}, None),
            ('dst', 'Destination', {'align':'right'}, None),
            ('status', 'Etat', {'align':'left'}, None),
            ('proto', 'Protocole', {'align':'left'}, None),
            ('mark', 'Marquage', {'align':'left'}, None),
            ('timeout', 'Timeout', {'align':'left'}, None),
            ])
        self.data = [self.table]
        # Created lazily on the first measure() call.
        self.collector = None

    def measure(self):
        """Refresh the conntrack snapshot and describe every connection.

        Returns ``{'statistics': rows}`` where ``rows`` holds one dict per
        entry of ``self.collector.current`` (see ``_connection_fields``).
        """
        if self.collector is None:
            self.collector = DataCollector('conntrack', None)
        self.collector.refresh()
        meas_data = [self._connection_fields(conn_id, conn)
                     for conn_id, conn in self.collector.current.items()]
        return {'statistics': meas_data}

    def _connection_fields(self, conn_id, conn):
        """Build the row dict for one collected connection.

        The dict carries every table column plus the intermediate
        ``l3src`` / ``l4src`` keys used to build ``proto``.
        """
        conn_data = conn.conntrack
        fields = {}
        fields['id'] = conn_id
        fields['mark'] = conn_data.mark
        fields['user'] = conn.username
        fields['timeout'] = str(int(conn_data.timeout))
        # FIXME (original note): once the binding bug on ``status`` is fixed,
        # the intent was to decode ntohl(conn_data.status) through TCPSTATES
        # instead of showing only the TCP state.
        # Unknown state values fall back to their raw value rather than
        # raising KeyError, which would abort the whole measure.
        tcp_state = conn_data.tcp_state
        fields['status'] = TCPSTATES.get(tcp_state, str(tcp_state))
        fields['src'] = str(IP(conn_data.orig_ipv4_src))
        fields['dst'] = str(IP(conn_data.orig_ipv4_dst))
        # NOTE(review): this compares the layer-3 protocol (a PF_* family,
        # looked up in IPPROTO_FAMILIES below) with IPPROTO_ICMP, an IP
        # protocol number, so the ICMP branch only runs when the family value
        # happens to equal IPPROTO_ICMP. It was probably meant to test
        # orig_l4proto; kept as-is until the intended ICMP display is confirmed.
        if conn_data.orig_l3proto == constant.IPPROTO_ICMP:
            fields['l3src'] = int(conn_data.icmp_id)
            fields['l4src'] = int(conn_data.icmp_type)
        else:
            # Symbolic names when known, raw numeric value otherwise.
            l3proto = conn_data.orig_l3proto
            l4proto = conn_data.orig_l4proto
            fields['l3src'] = str(IPPROTO_FAMILIES.get(l3proto, l3proto))
            fields['l4src'] = str(IPPROTO_NAMES.get(l4proto, l4proto))
        fields['proto'] = "%s/%s" % (fields['l3src'], fields['l4src'])
        fields['port'] = "%s/%s" % (int(conn_data.orig_port_src),
                                    int(conn_data.orig_port_dst))
        # FIXME (original note): show rates as a % of the available bandwidth,
        # or relative to all connections?
        fields['rate_in'] = float(conn.orig_byterate)
        fields['rate_out'] = float(conn.repl_byterate)
        return fields

    def check_status(self):
        """Always report OK: this agent only displays data."""
        return status.OK()

    def save_measure(self, measure):
        """Store the measure via the base Agent and keep it for write_data()."""
        Agent.save_measure(self, measure)
        self.last_measure = measure

    def write_data(self):
        """Write agent data and refresh the table with the last measure."""
        Agent.write_data(self)
        # This class only assigns last_measure in save_measure(); fall back to
        # None so a write before any measure does not raise AttributeError.
        last_measure = getattr(self, 'last_measure', None)
        if last_measure is not None:
            # presumably .value wraps the dict returned by measure() — verify
            # against Agent.save_measure
            self.table.table_data = last_measure.value['statistics']

