import sys
import maya.OpenMayaMPx as omMPx
import maya.OpenMaya as om
import maya.cmds as cmds

'''
Warning: There is a bug in Maya 2018's Python API:
editing the membership of deformer plug-in nodes has caused nodes to delete themselves or crash Maya.
'''

'''=======================================================
## Custom Plug-in Class
======================================================='''
class displace(omMPx.MPxDeformerNode):
    ''' Deformer that pushes each member vertex along its vertex normal.

    new position = position + normal * magnitude * paint weight * envelope
    '''

    ##NOTE:plugin Registration Info
    kPluginNodeName = 'displace' 
    kPluginNodeId = om.MTypeId(0xBEEF8)

    #NOTE: plugin constants (inherited geometry-filter attributes, exposed via SWIG cvar)
    K_INPUT = omMPx.cvar.MPxGeometryFilter_input
    K_INPUT_GEOM = omMPx.cvar.MPxGeometryFilter_inputGeom
    K_OUTPUT_GEOM = omMPx.cvar.MPxGeometryFilter_outputGeom
    K_ENVELOPE = omMPx.cvar.MPxGeometryFilter_envelope

    # Static stand-in for the node's custom attribute; filled in by nodeInitializer().
    magnitude_attr = om.MObject()


    def __init__(self):
        ''' Constructor. '''
        # (!) Make sure you call the base class's constructor.
        omMPx.MPxDeformerNode.__init__(self)

    '''=======================================================
    ## class initialization.
    ======================================================='''
    @classmethod
    def nodeCreator(cls):
        ''' Create an instance of the node class and hand it to Maya as a pointer. '''
        return omMPx.asMPxPtr(cls())

    @classmethod
    def nodeInitializer(cls):
        ''' Create the node's attributes, declare their dependencies and make weights paintable. '''
        numericAttributeFn = om.MFnNumericAttribute()

        #==================================
        # INPUT NODE ATTRIBUTE(S)
        #==================================
        # Distance (in the units of the normals) each vertex is pushed along its normal.
        cls.magnitude_attr = numericAttributeFn.create('magnitude', 'mag', om.MFnNumericData.kDouble, 0.0)
        numericAttributeFn.setStorable(True)
        numericAttributeFn.setWritable(True)
        numericAttributeFn.setKeyable(True)
        cls.addAttribute(cls.magnitude_attr)

        #==================================
        # NODE ATTRIBUTE DEPENDENCIES
        #==================================
        cls.attributeAffects(cls.magnitude_attr, cls.K_OUTPUT_GEOM)
        # Make deformer weights paintable
        cmds.makePaintable(cls.kPluginNodeName, 'weights', attrType='multiFloat', shapeMode='deformer')

    '''=======================================================
    ## Deformation Operation
    ======================================================='''
    def getDeformerInputGeometry(self, pDataBlock, pGeometryIndex):
        ''' Return the mesh MObject stored at input[pGeometryIndex].inputGeom.

        pDataBlock     -- MDataBlock of the current evaluation.
        pGeometryIndex -- logical index of the geometry being deformed.
        '''
        # NOTE(review): the input array is read through outputArrayValue()/outputValue()
        # rather than inputArrayValue(); presumably to avoid forcing an extra upstream
        # evaluation -- kept unchanged.
        inputHandle = pDataBlock.outputArrayValue(self.K_INPUT)
        inputHandle.jumpToElement(pGeometryIndex)
        return inputHandle.outputValue().child(self.K_INPUT_GEOM).asMesh()

    def deform(self, pDataBlock, pGeometryIterator, pLocalToWorldMatrix, pGeometryIndex):
        ''' Offset every iterated point along its vertex normal.

        pDataBlock          -- MDataBlock of the current evaluation.
        pGeometryIterator   -- MItGeometry over the deformer's member components.
        pLocalToWorldMatrix -- unused here.
        pGeometryIndex      -- logical index of the geometry being deformed.
        '''
        envelopeValue = pDataBlock.inputValue(self.K_ENVELOPE).asFloat()
        magnitude = pDataBlock.inputValue(self.magnitude_attr).asDouble()

        # With a zero envelope or magnitude every offset is zero: nothing to do.
        if envelopeValue == 0.0 or magnitude == 0.0:
            return

        inputGeometryObject = self.getDeformerInputGeometry(pDataBlock, pGeometryIndex)
        # A plain `not obj` test does not detect a null MObject; ask it directly.
        if inputGeometryObject.isNull():
            return

        # Per-vertex normals of the input mesh, indexed by vertex id.
        normals = om.MFloatVectorArray()
        om.MFnMesh(inputGeometryObject).getVertexNormals(True, normals, om.MSpace.kTransform)

        # allPositions() returns only the components the iterator visits, in
        # iteration order. When membership is a subset of the mesh this array is
        # NOT indexed by vertex id, so the slot counter is kept separately from
        # the vertex index used for normals and weights.
        positions = om.MPointArray()
        pGeometryIterator.allPositions(positions)

        scale = magnitude * envelopeValue
        slot = 0
        while not pGeometryIterator.isDone():
            vertexId = pGeometryIterator.index()
            weight = self.weightValue(pDataBlock, pGeometryIndex, vertexId)
            offset = om.MVector(normals[vertexId]) * (scale * weight)
            positions.set(positions[slot] + offset, slot)
            slot += 1
            pGeometryIterator.next()

        # Write every iterated position back in one call.
        pGeometryIterator.setAllPositions(positions)
    
'''=======================================================
## Plug-in initialization.
======================================================='''

def initializePlugin(mobject):
    ''' Maya plug-in entry point: register the displace deformer node.

    mobject -- the plug-in MObject handed in by Maya.
    Re-raises any registration error after logging it to stderr.
    '''
    mplugin = omMPx.MFnPlugin(mobject)
    try:
        mplugin.registerNode(displace.kPluginNodeName,
                             displace.kPluginNodeId,
                             displace.nodeCreator,
                             displace.nodeInitializer,
                             omMPx.MPxNode.kDeformerNode)
    except Exception:
        # kPluginNodeName lives on the class; the bare module-level name is
        # undefined and would raise NameError here, hiding the real error.
        sys.stderr.write('Failed to register node: {}\n'.format(displace.kPluginNodeName))
        raise
    
def uninitializePlugin(mobject):
    ''' Maya plug-in exit point: deregister the displace deformer node.

    mobject -- the plug-in MObject handed in by Maya.
    Re-raises any deregistration error after logging it to stderr.
    '''
    mplugin = omMPx.MFnPlugin(mobject)
    try:
        mplugin.deregisterNode(displace.kPluginNodeId)
    except Exception:
        # kPluginNodeName lives on the class; the bare module-level name is
        # undefined and would raise NameError here, hiding the real error.
        sys.stderr.write('Failed to deregister node: {}\n'.format(displace.kPluginNodeName))
        raise





