'''
*   This script was created by
*       Trevor Sommer
*           trevor@trevorsommer.com

Module containing numerous rigging functions.
All of these functions currently assume, where applicable, that X is the axis pointing down the joint chain for all scaling operations
'''

import maya.cmds as cmds
from . import rigUtil as rUtil


def createRotationSwitch(drivenJoint, fkJoint, ikJoint, controller):
    '''
    Creates a rotation connection between the three passed joints with a controller attribute
    to blend the values between the fk and ik joints via a blendColors utility node.

    drivenJoint -- joint whose rotateX/Y/Z are driven by the blend output
    fkJoint     -- joint feeding the blend's color1 (selected when ik_TO_fk == 1)
    ikJoint     -- joint feeding the blend's color2 (selected when ik_TO_fk == 0)
    controller  -- object that holds (or receives) the keyable 0-1 "ik_TO_fk" attribute

    Returns the name of the created blendColors node, or False if validation fails.
    '''
    if not all(cmds.objExists(obj) for obj in (drivenJoint, fkJoint, ikJoint, controller)):
        # the function name is written as text; concatenating the function object itself raises TypeError
        cmds.warning('One or more of the objects passed into the function "createRotationSwitch" does not exist in the current scene aborting operation')
        return False

    if any(cmds.nodeType(jnt) != "joint" for jnt in (drivenJoint, fkJoint, ikJoint)):
        cmds.warning('Improper Input passed into function createRotationSwitch. "' + drivenJoint + '", "' + ikJoint + '", and "' + fkJoint + '" objects must all be of the type "joint"')
        return False

    if not cmds.attributeQuery("ik_TO_fk", exists= True, node= controller):
        #if "ik_TO_fk" attribute doesnt exist on the passed controller create it
        cmds.addAttr(controller, longName= "ik_TO_fk", attributeType= "double", keyable= True, minValue= 0, maxValue= 1, defaultValue= 0)

    #create a utility node to blend the values; the controller's ik_TO_fk attribute drives the blend
    switchNode = cmds.shadingNode("blendColors", asUtility= True, name=(drivenJoint + "_SwitchNode_Rotate"))

    # Always connect the direct per-channel attributes (".outputR >> .rotateX"), never ".output >> .rotate":
    # past a certain rig complexity the compound connection sometimes does not evaluate
    # properly without scrubbing the time slider.
    for axis, channel in zip("XYZ", "RGB"):
        cmds.connectAttr((fkJoint + ".rotate" + axis), (switchNode + ".color1" + channel), force= True)
        cmds.connectAttr((ikJoint + ".rotate" + axis), (switchNode + ".color2" + channel), force= True)
        cmds.connectAttr((switchNode + ".output" + channel), (drivenJoint + ".rotate" + axis), force= True)

    cmds.connectAttr((controller + ".ik_TO_fk"), (switchNode + ".blender"), force= True)
    return switchNode

def createScaleSwitch(drivenJoint, fkJoint, ikJoint, controller):
    '''
    Creates a scaleX connection between the three passed joints with a controller attribute
    to blend the values between the fk and ik joints via a blendColors utility node.
    Only scaleX is blended (X is assumed to point down the joint chain).

    drivenJoint -- joint whose scaleX is driven by the blend output
    fkJoint     -- joint feeding color1R (selected when ik_TO_fk == 1)
    ikJoint     -- joint feeding color2R (selected when ik_TO_fk == 0)
    controller  -- object that holds (or receives) the keyable 0-1 "ik_TO_fk" attribute

    Returns the name of the created blendColors node, or False if validation fails.
    '''
    if not all(cmds.objExists(obj) for obj in (drivenJoint, fkJoint, ikJoint, controller)):
        # the function name is written as text; concatenating the function object itself raises TypeError
        cmds.warning('One or more of the objects passed into the function "createScaleSwitch" does not exist in the current scene aborting operation')
        return False

    if any(cmds.nodeType(jnt) != "joint" for jnt in (drivenJoint, fkJoint, ikJoint)):
        cmds.warning('Improper Input passed into function createScaleSwitch. "' + drivenJoint + '", "' + ikJoint + '", and "' + fkJoint + '" objects must all be of the type "joint"')
        return False

    if not cmds.attributeQuery("ik_TO_fk", exists= True, node= controller):
        #if "ik_TO_fk" attribute doesnt exist on the passed controller create it
        cmds.addAttr(controller, longName= "ik_TO_fk", attributeType= "double", keyable= True, minValue= 0, maxValue= 1, defaultValue= 0)

    #create a utility node to blend the values; the controller's ik_TO_fk attribute drives the blend
    switchNode = cmds.shadingNode("blendColors", asUtility= True, name=(drivenJoint + "_SwitchNode_Scale"))
    cmds.connectAttr((ikJoint + ".scaleX"), (switchNode + ".color2R"), force= True)
    cmds.connectAttr((fkJoint + ".scaleX"), (switchNode + ".color1R"), force= True)
    cmds.connectAttr((switchNode + ".outputR"), (drivenJoint + ".scaleX"), force= True)
    cmds.connectAttr((controller + ".ik_TO_fk"), (switchNode + ".blender"), force= True)
    return switchNode

def distroSubRig(baseJoint, targetJoint, numJoints, charPfx, connectionMethod, armLegDesignation):
    '''
    Creates numJoints rotation distribution joints between baseJoint and targetJoint,
    aligning them to baseJoint, spacing them evenly and connecting their rotateX
    to a driver through multiplyDivide utility nodes.

    baseJoint         -- upper joint of the segment; every new joint is parented under it
    targetJoint       -- lower joint of the segment
    numJoints         -- number of distribution joints to create (positive int)
    charPfx           -- character prefix used in the new joint names
    connectionMethod  -- "down": rotateX of each new joint is driven by targetJoint
                         "up":   rotateX is driven by "<baseJoint>_Rotation_Master", a parentOnly
                                 duplicate of baseJoint created on first use (targetJoint is
                                 re-parented under it)
    armLegDesignation -- limb label used in the new joint names

    New joints are named "<charPfx>_<armLegDesignation>_<connectionMethod>_<n>_JNT".
    Returns the list of created joint names, or False on invalid input / name clash.
    '''
    if not all(isinstance(arg, str) for arg in (armLegDesignation, charPfx, connectionMethod)):
        # str() so building the warning cannot itself raise when a non-string was passed
        cmds.warning('Improper Input passed into function distroSubRig. "' + str(armLegDesignation) + '", "' + str(connectionMethod) + '", and "' + str(charPfx) + '" must all be of type string')
        return False

    method = connectionMethod.lower()
    if method not in ('up', 'down'):
        cmds.warning('Improper Input passed into function distroSubRig. "' + connectionMethod + '" can only be "up" or "down"')
        return False

    if (not cmds.objExists(baseJoint) or not cmds.objExists(targetJoint)
            or cmds.nodeType(baseJoint) != "joint" or cmds.nodeType(targetJoint) != "joint"):
        cmds.warning('Improper Input passed into function distroSubRig. "' + str(baseJoint) + '" and "' + str(targetJoint) + '" objects must all be of the type "joint"')
        return False

    # bool is a subclass of int; reject it along with non-positive counts
    if isinstance(numJoints, bool) or not isinstance(numJoints, int) or numJoints < 1:
        cmds.warning('Improper Input passed into function distroSubRig. "' + str(numJoints) + '" must be a positive integer')
        return False

    # validate every name up front so a clash does not leave a half-built rig behind
    jointNames = [charPfx + "_" + armLegDesignation + "_" + connectionMethod + "_" + str(numJoints - i) + "_JNT"
                  for i in range(numJoints)]
    for jointName in jointNames:
        if cmds.objExists(jointName):
            cmds.warning('An object already exists with the name ' + jointName + ' canceling distroSubRig operation')
            return False

    createdJoints = []

    # 1.0 forces float division (older Maya ships Python 2, where 1 / n would be 0)
    weightFactor = 1.0 / (numJoints + 1)
    currentWeight = 0.0

    for newJointName in jointNames:
        #clear selection so the new joint is not parented to whatever is selected
        cmds.select(clear= True)

        currentWeight += weightFactor
        remainder = 1 - currentWeight

        cmds.joint(name= newJointName, position= [0,0,0])

        #point constrain between base and target; pointConstraint returns a list of names
        pointCon = cmds.pointConstraint(baseJoint, targetJoint, newJointName, offset= [0,0,0], weight= 1.0)[0]
        #orient constraint is only used to align with the base joint, so delete it immediately
        cmds.delete(cmds.orientConstraint(baseJoint, newJointName, offset= [0,0,0], weight= 1.0))

        #weight aliases follow target order (base first, target second); query them rather than
        #rebuilding "<name>W0" by hand, which breaks for namespaced or path names
        baseWeightAttr, targetWeightAttr = cmds.pointConstraint(pointCon, query= True, weightAliasList= True)
        cmds.setAttr((pointCon + "." + baseWeightAttr), currentWeight)
        cmds.setAttr((pointCon + "." + targetWeightAttr), remainder)
        cmds.delete(pointCon)

        #freeze the joint's transformations
        cmds.makeIdentity(newJointName, apply= True, scale= True, rotate= True, translate= True, normal= True)

        #parent to baseJoint which should be the upper joint in the chain
        cmds.parent(newJointName, baseJoint)

        if method == "down":
            #rotateX is driven by the targetJoint in this method
            driver = targetJoint
        else:
            #the up version (shoulder/hip) uses a rotation master joint to isolate X rotations,
            #because in the final rig the baseJoint rotateX should always be 0
            driver = baseJoint + "_Rotation_Master"
            if not cmds.objExists(driver):
                cmds.duplicate(baseJoint, parentOnly= True, name= driver)
                cmds.parent(targetJoint, driver)
                # TODO(original author): double check whether inverseScale needs connecting here

        #make sure rotate values are all zero before wiring rotateX
        for axis in "XYZ":
            cmds.setAttr((newJointName + ".rotate" + axis), 0)

        #scale the driver's rotateX by this joint's distribution percent
        divMulti = cmds.shadingNode('multiplyDivide', asUtility= True, name=(newJointName + "_MULTI"))
        cmds.setAttr((divMulti + ".input2X"), remainder)
        cmds.connectAttr((driver + ".rotateX"), (divMulti + ".input1X"), force= True)
        cmds.connectAttr((divMulti + ".outputX"), (newJointName + ".rotateX"), force= True)

        createdJoints.append(newJointName)

    return createdJoints

def duplicateJointHierarchy(baseJoint, dupType, charPfx = ""):
    '''
    Duplicates baseJoint and the whole joint chain hierarchy underneath it, then renames every
    duplicated joint by inserting the dupType string after the passed charPfx.

    Example: duplicateJointHierarchy("pig_joint1", "fk", "pig") results in a joint named "pig_fk_joint1"
    With no charPfx, dupType becomes a prefix: duplicateJointHierarchy("joint1", "fk") -> "fk_joint1"

    Note: this replaces the current selection (the hierarchy is gathered via cmds.select).

    Returns the list of renamed duplicate joints, or False on failure.
    '''
    if not dupType:
        cmds.warning("Invalid dupType passed into duplicateJointHierarchy, must pass a valid string")
        return False

    cmds.select(baseJoint, replace= True, hierarchy= True)
    originalJointList = cmds.ls(selection= True, flatten= True, type= "joint")
    if not originalJointList:
        cmds.warning("duplicateJointHierarchy found no valid joints to duplicate")
        return False

    dupJointList = cmds.duplicate(baseJoint, renameChildren= True)
    if len(originalJointList) != len(dupJointList):
        # a count mismatch means non-joint nodes were duplicated as well; remove the
        # (root-first) duplicate so no unrenamed copy is left in the scene
        cmds.delete(dupJointList[0])
        cmds.warning("duplicateJointHierarchy has mismatched joint lists after duplication operation canceling")
        return False

    duplicatedJoints = []
    for originalName, dupName in zip(originalJointList, dupJointList):
        if charPfx:
            #function assumes charPfx is always followed by "_"; only the first occurrence is the prefix
            newName = originalName.replace(charPfx, (charPfx + "_" + dupType), 1)
        else:
            #if no charPfx is designated use the dupType as a prefix
            newName = dupType + "_" + originalName

        duplicatedJoints.append(cmds.rename(dupName, newName))

    return duplicatedJoints

def fkControllerScalable(fkJoint):
    '''
    Creates an FK controller for the passed joint that allows for FK based joint scaling.

    Builds:
      "<fkJoint>_CON_STABILIZER" -- parentOnly duplicate of fkJoint, made fkJoint's parent
      "<fkJoint>_CON_OFFSET_GRP" -- empty group parent-constrained to the stabilizer joint
      "<fkJoint>_FK_CON"         -- nurbs circle under the offset group whose rotateX/Y/Z and
                                    scaleX drive fkJoint

    Returns the controller transform name, or False on invalid input / name clash.
    '''
    if not cmds.objExists(fkJoint) or cmds.nodeType(fkJoint) != "joint":
        cmds.warning('Improper Input passed into function fkControllerScalable. "' + fkJoint + '" must be of the type "joint"')
        return False

    fkControllerName = fkJoint + "_FK_CON"
    if cmds.objExists(fkControllerName):
        cmds.warning('Controller object "' + fkControllerName + '" already exists canceling execution of fkControllerScalable.')
        return False

    #create joint scale offset for stabilization of rig (duplicate returns a list of names)
    offsetJoint = cmds.duplicate(fkJoint, parentOnly= True, name= (fkJoint + "_CON_STABILIZER"))[0]
    cmds.parent(fkJoint, offsetJoint)

    #create FK controller offset group that follows the stabilizer joint
    controllerGRP = cmds.group(empty= True, name= (fkJoint + "_CON_OFFSET_GRP"))
    cmds.parentConstraint(offsetJoint, controllerGRP, weight= 1)

    #create nurbsCircle controller (circle returns [transform, makeNurbsCircle]), snap it to fkJoint, then parent it
    fkController = cmds.circle(center= [0,0,0], normal= [1,0,0], sweep= 360, radius= 1, degree= 3, useTolerance= 0, sections= 8, constructionHistory= True, name= fkControllerName)[0]
    cmds.delete(cmds.parentConstraint(fkJoint, fkController, weight= 1))
    fkController = cmds.parent(fkController, controllerGRP)[0]

    #freeze transformations so the controller's zeroed channels match the fkJoint orientation
    cmds.makeIdentity(fkController, apply= True, scale= True, rotate= True, translate= True, normal= True)
    cmds.delete(fkController, constructionHistory= True)

    #connect the controller rotations and scaleX to drive the fkJoint
    for axis in "XYZ":
        cmds.connectAttr((fkController + ".rotate" + axis), (fkJoint + ".rotate" + axis), force= True)
    cmds.connectAttr((fkController + ".scaleX"), (fkJoint + ".scaleX"), force= True)

    # inverseScale connections are sometimes not made by default after re-parenting;
    # ensuring them avoids shearing of controllers/geometry when animating scales
    if not cmds.isConnected((offsetJoint + ".scale"), (fkJoint + ".inverseScale")):
        cmds.connectAttr((offsetJoint + ".scale"), (fkJoint + ".inverseScale"), force= True)

    parentJoints = cmds.listRelatives(offsetJoint, parent= True, type= 'joint')
    # a root joint has no parent joint (listRelatives returns None), so there is nothing to connect
    if parentJoints and not cmds.isConnected((parentJoints[0] + ".scale"), (offsetJoint + ".inverseScale")):
        cmds.connectAttr((parentJoints[0] + ".scale"), (offsetJoint + ".inverseScale"), force= True)

    #lock and hide all pertinent attributes in the controller hierarchy
    rUtil.attrLocker(controllerGRP, ["tx","ty","tz","rx","ry","rz","sx","sy","sz","v"])
    rUtil.attrLocker(fkController, ["tx","ty","tz","sy","sz","v"])

    return fkController

def simpleStretchIKSetup(handle, controller):
    '''
    Creates a stretch setup on the IK chain of the passed ikHandle and places the
    "stretchy_Attributes" (locked divider) and "stretch" (0-1 double) attributes on the controller.

    The chain's rest length is the summed distance between consecutive joint rotate pivots.
    A distanceBetween node measures the live handle-to-root distance; a condition node takes
    the larger of live and rest length, which is divided by the rest length and blended with 1.0
    by the controller's "stretch" attribute before driving every chain joint's scaleX
    (the end joint is excluded). Helper nulls are parented under a hidden "scale_setup_GRP".

    Returns the name of the blendColors node driving scaleX, or False on invalid input.
    '''
    if not cmds.objExists(handle) or not cmds.objExists(controller):
        cmds.warning('One or more of the objects passed into simpleStretchIKSetup does not exist')
        return False

    if cmds.nodeType(handle) != "ikHandle":
        cmds.warning('Handle object "' + handle + '" is not of type ikHandle stretchy setup will not be created')
        return False

    #gather the chain before creating anything so invalid chains leave the scene untouched
    jointNames = cmds.ikHandle(handle, query= True, jointList= True)
    effectorName = cmds.ikHandle(handle, query= True, endEffector= True)
    endJoints = cmds.listConnections(effectorName, type= 'joint') if effectorName else None
    if not jointNames or not endJoints:
        cmds.warning('Could not resolve the joint chain of "' + handle + '" stretchy setup will not be created')
        return False

    chain = list(jointNames) + [endJoints[0]]

    #total rest length of the chain, used as the base stretch value
    chainDistance = 0.0
    for startJoint, endJoint in zip(chain[:-1], chain[1:]):
        pos1 = cmds.pointPosition(startJoint + '.rotatePivot')
        pos2 = cmds.pointPosition(endJoint + '.rotatePivot')
        chainDistance += rUtil.vectDist(pos1, pos2)

    if chainDistance <= 0:
        # dividing by a zero rest length would break the multiplyDivide node
        cmds.warning('Joint chain of "' + handle + '" has zero length stretchy setup will not be created')
        return False

    #Create stretch toggles and offset attribute onto the designated controller object
    if not cmds.attributeQuery("stretchy_Attributes", exists= True, node= controller):
        cmds.addAttr(controller, attributeType= 'bool', longName= "stretchy_Attributes",hidden= False,  keyable= True, minValue= 0, maxValue= 1, defaultValue= 0)
        cmds.setAttr((controller + ".stretchy_Attributes"), True, lock= True)

    if not cmds.attributeQuery("stretch", exists= True, node= controller):
        cmds.addAttr(controller, attributeType= 'double', longName= "stretch",hidden= False,  keyable= True, minValue= 0, maxValue= 1, defaultValue= 0)

    #create utility nodes for stretch setup
    distanceBetween = cmds.shadingNode("distanceBetween", asUtility= True, name=(handle + "_stretch_DB"))
    stretchConditional = cmds.shadingNode("condition", asUtility= True, name=(handle + "_stretch_COND"))
    stretchMulti = cmds.shadingNode("multiplyDivide", asUtility= True, name=(handle + "_stretch_MD"))
    stretchBlend = cmds.shadingNode("blendColors", asUtility= True, name=(handle + "_stretch_BLEND"))

    baseLoc = cmds.group(name= (handle + '_BASE_LOC'), empty= True)
    endLoc = cmds.group(name= (handle + '_END_LOC'), empty= True)

    #constrain the null groups to the handle and the chain root for distance calculations
    cmds.pointConstraint(handle, baseLoc, maintainOffset= False)
    cmds.pointConstraint(jointNames[0], endLoc, maintainOffset= False)

    cmds.connectAttr((baseLoc + '.translate'), (distanceBetween + '.point2'), force= True)
    cmds.connectAttr((endLoc + '.translate'), (distanceBetween + '.point1'), force= True)

    #condition op 3 (greater or equal): rest >= live ? rest : live  -> max(rest, live)
    cmds.connectAttr((distanceBetween + '.distance'), (stretchConditional + '.secondTerm'), force= True)
    cmds.connectAttr((distanceBetween + '.distance'), (stretchConditional + '.colorIfFalseR'), force= True)
    cmds.setAttr((stretchConditional + '.operation'), 3)
    cmds.setAttr((stretchConditional + ".firstTerm"), chainDistance)
    cmds.setAttr((stretchConditional + ".colorIfTrueR"), chainDistance)

    #multiplyDivide op 2 (divide): length / rest length
    cmds.connectAttr((stretchConditional + '.outColorR'), (stretchMulti + '.input1X'), force= True)
    cmds.setAttr((stretchMulti + '.operation'), 2)
    cmds.setAttr((stretchMulti + ".input2X"), chainDistance)

    #blend between computed stretch (color1) and no stretch (color2 = 1) via controller.stretch
    cmds.connectAttr((stretchMulti + '.outputX'), (stretchBlend + '.color1R'), force= True)
    cmds.connectAttr((controller + '.stretch'), (stretchBlend + '.blender'), force= True)
    cmds.setAttr((stretchBlend + ".color2"), 1, 1, 1, type= 'double3')

    #connect scale into every chain joint's scaleX (end joint excluded)
    for joint in jointNames:
        cmds.connectAttr((stretchBlend + '.outputR'), (joint + '.scaleX'), force= True)

    #hierarchy cleanup for setup nodes
    if not cmds.objExists('scale_setup_GRP'):
        cmds.group(empty= True, name= 'scale_setup_GRP')

    cmds.parent(baseLoc, endLoc, 'scale_setup_GRP')
    cmds.setAttr("scale_setup_GRP.visibility", 0)

    return stretchBlend

def rig_SpaceCreator(child, parentName, spaceName = "", otherControls = None):
    '''
    Initializes or extends a space switching setup on a controller.

    child         -- controller that receives the enum "space" attribute; it (and each existing
                     otherControls entry) is grouped under "<name>_Main_OFFSET", and those groups
                     are parented under "<child>_SPACE_OFFSET" the first time this runs
    parentName    -- object added as a parentConstraint target of the space offset group
    spaceName     -- enum label for the new space; defaults to parentName.
                     "init" only creates the "space" attribute with a single "world" entry.
    otherControls -- list of strings naming child controllers/objects that should share the same
                     space; existing ones are recorded as message connections under a
                     "spaceChildren" compound attribute on child

    Index 0 of the enum is "world": with every constraint weight at 0 the constraint falls back
    to its rest values, which are zeroed here.

    Returns the parentConstraint name when a space is added, None for "init",
    or False on invalid input / duplicate space name.
    '''
    # None default avoids sharing one mutable list between calls
    if otherControls is None:
        otherControls = []

    if not cmds.objExists(child):
        cmds.warning('rig_SpaceCreator: controller "' + str(child) + '" does not exist in the current scene')
        return False

    spaceOffset = child + "_SPACE_OFFSET"
    spaceConstraintName = child + "_SpaceSwitcher"

    validOthers = [item for item in otherControls if cmds.objExists(item)]
    controllerSet = [child] + validOthers

    #record the child controls with message connections so the switch's scope is easy to manage
    if validOthers and not cmds.attributeQuery("spaceChildren", exists= True, node= child):
        cmds.addAttr(child, longName= 'spaceChildren', numberOfChildren= len(validOthers), attributeType= 'compound')
        # the compound only exists once all declared children are added, so connect afterwards
        for item in validOthers:
            cmds.addAttr(child, longName= item, attributeType= 'message', parent= 'spaceChildren')
        for item in validOthers:
            cmds.connectAttr((item + ".message"), (child + "." + item), force= True)

    #create space offset node and move every controller's main offset group under it
    if not cmds.objExists(spaceOffset):
        cmds.group(empty= True, name= spaceOffset)
        for con in controllerSet:
            mainOffset = con + "_Main_OFFSET"
            if not cmds.objExists(mainOffset):
                cmds.group(con, name= mainOffset)
            cmds.parent(mainOffset, spaceOffset)

    hasSpaceAttr = cmds.attributeQuery("space", exists= True, node= child)
    existingSpaces = []
    if hasSpaceAttr:
        enumString = cmds.addAttr((child + ".space"), query= True, enumName= True) or ""
        existingSpaces = [space for space in enumString.split(":") if space]

    #"init" only creates the space attribute with the world entry
    if spaceName.lower() == "init":
        if not hasSpaceAttr:
            cmds.addAttr(child, keyable= True, longName= "space", attributeType= "enum", enumName= "world:")
        return None

    if not cmds.objExists(parentName):
        cmds.warning('rig_SpaceCreator: parent "' + str(parentName) + '" does not exist in the current scene')
        return False

    #if no spaceName is passed use the parent object's name as the space name
    if not spaceName:
        spaceName = parentName

    if spaceName in existingSpaces:
        cmds.warning('The spaceName "' + spaceName + '" already exists on the passed controller')
        return False

    if not hasSpaceAttr:
        cmds.addAttr(child, keyable= True, longName= "space", attributeType= "enum", enumName= "world:")
        existingSpaces = ["world"]

    #add parentName as a target (parentConstraint returns a list of names)
    if cmds.objExists(spaceConstraintName):
        spaceConstraint = cmds.parentConstraint(parentName, spaceOffset, weight= 1)[0]
    else:
        spaceConstraint = cmds.parentConstraint(parentName, spaceOffset, weight= 1, name= spaceConstraintName)[0]

    #zero out any constraint rest offsets that may be there
    for restAttr in ["restTranslateZ", "restTranslateX", "restTranslateY", "restRotateX", "restRotateY", "restRotateZ"]:
        cmds.setAttr((spaceConstraint + "." + restAttr), 0)

    #find this target's weight attribute instead of rebuilding "<name>W<i>" by hand
    targets = cmds.parentConstraint(spaceConstraint, query= True, targetList= True) or []
    aliases = cmds.parentConstraint(spaceConstraint, query= True, weightAliasList= True)
    weightAttr = aliases[targets.index(parentName)] if parentName in targets else aliases[-1]

    #append the new space to the enum; its enum index is its position in the list
    updatedSpaces = existingSpaces + [spaceName]
    cmds.addAttr((child + ".space"), edit= True, enumName= ":".join(updatedSpaces))
    spaceIndex = len(updatedSpaces) - 1

    #condition: weight = 1 when child.space == spaceIndex, else 0
    spaceCondNode = cmds.shadingNode("condition", asUtility= True, name= (child + "_" + spaceName + "_Space_COND"))
    cmds.setAttr((spaceCondNode + ".colorIfTrueR"), 1)
    cmds.setAttr((spaceCondNode + ".colorIfFalseR"), 0)
    cmds.setAttr((spaceCondNode + ".secondTerm"), spaceIndex)
    cmds.setAttr((spaceCondNode + ".operation"), 0)

    cmds.connectAttr((child + ".space"), (spaceCondNode + ".firstTerm"))
    cmds.connectAttr((spaceCondNode + ".outColorR"), (spaceConstraint + "." + weightAttr))

    return spaceConstraint



