from PySide2.QtWidgets import *
import maya.cmds as cmds
from BaseUI import BaseUI

# Consts
# Rotate-order names offered by the "Rotate Order" dropdown. The dropdown's
# selected index is used to look up the name in this list, so the order here
# must match the order of the dropdown entries.
VALID_ROTATE_ORDERS = ["xyz", "yzx", "zxy", "xzy", "yxz", "zyx"]
# Target of the (currently commented-out) "Help" button.
# NOTE(review): the value has no URL scheme such as "https://" - confirm
# whether the code that opens it requires one.
DOCUMENTATION_URL = "www.google.com"

# Build the tool window and register one widget per rigging operation.
ui = BaseUI('General Rigging Tools')

# Every callback is wrapped in a lambda on purpose: the handler functions are
# defined further down in this file, so the lambda defers the name lookup
# until the widget actually fires.
ui.add_dropdown(
    "Joint Draw Style",
    ["Bone", "Multi-child as Box", "None"],
    lambda index: set_hierarchy_draw_style(index),
)
ui.add_slider("Joint Radius", 1, 10, lambda value: set_joint_radius(value))
ui.add_dropdown(
    "HIK Joint Look Style",
    ["None", "Bone", "Stick", "Box", "Circle", "Square"],
    lambda index: set_hierarchy_look_style(index),
)
ui.add_dropdown(
    "Generate Shape",
    ["Circle", "Square"],
    lambda index: generate_shape(index),
)
ui.add_checkbox("Toggle Local Axis", lambda state: set_local_axis(state))
# The entries come from VALID_ROTATE_ORDERS (passed as a copy) so that the
# labels shown and the index lookup in the callback can never drift apart.
ui.add_dropdown(
    "Rotate Order",
    list(VALID_ROTATE_ORDERS),
    lambda index: set_rotate_order(VALID_ROTATE_ORDERS[index]),
)

# ui.add_button("Help", 5, lambda: ui.open_website(DOCUMENTATION_URL))
ui.add_close_button(10)

# Function to set the draw style of a single node (descendants are handled by set_hierarchy_draw_style)
def set_draw_style(node, drawStyle):
    """Write ``drawStyle`` to the ``drawStyle`` attribute of ``node``.

    Args:
        node: Name of the node whose attribute is set.
        drawStyle: Value forwarded unchanged to ``cmds.setAttr``.
    """
    attribute_path = node + ".drawStyle"
    cmds.setAttr(attribute_path, drawStyle)

# Function to set the look style of a single node (descendants are handled by set_hierarchy_look_style)
def set_look_style(node, lookStyle):
    """Write ``lookStyle`` to the ``look`` attribute of ``node``.

    Args:
        node: Name of the node whose attribute is set.
        lookStyle: Value forwarded unchanged to ``cmds.setAttr``.
    """
    attribute_path = node + ".look"
    cmds.setAttr(attribute_path, lookStyle)

# Function to set the draw style of all bones in the hierarchy
def set_hierarchy_draw_style(drawStyle):
    """Apply a draw style to every selected joint and all joints below it.

    Selected nodes that are not of type ``joint`` are skipped, together with
    their hierarchy.

    Args:
        drawStyle: Value written to each joint's ``drawStyle`` attribute
            (the UI passes the index chosen in the "Joint Draw Style"
            dropdown).
    """
    for obj in cmds.ls(selection=True):
        if cmds.nodeType(obj) != "joint":
            continue

        set_draw_style(obj, drawStyle)

        # listRelatives returns None (not an empty list) when the joint has
        # no descendants, so fall back to [] to keep leaf joints from raising.
        # Full paths keep the names unambiguous when short names repeat.
        descendants = cmds.listRelatives(
            obj, allDescendents=True, fullPath=True
        ) or []
        for descendant in descendants:
            if cmds.nodeType(descendant) == "joint":
                set_draw_style(descendant, drawStyle)

# Function to set the look style of all effectors in the hierarchy
def set_hierarchy_look_style(lookStyle):
    """Apply a look style to every selected HIK FK joint and those below it.

    Only nodes whose type is exactly ``hikFKJoint`` are touched; any other
    selected node is skipped together with its hierarchy.

    Args:
        lookStyle: Value written to each node's ``look`` attribute (the UI
            passes the index chosen in the "HIK Joint Look Style" dropdown).
    """
    for obj in cmds.ls(selection=True):
        if cmds.nodeType(obj) != "hikFKJoint":
            continue

        set_look_style(obj, lookStyle)

        # listRelatives returns None (not an empty list) when the node has
        # no descendants, so fall back to [] to keep leaf nodes from raising.
        # Full paths keep the names unambiguous when short names repeat.
        descendants = cmds.listRelatives(
            obj, allDescendents=True, fullPath=True
        ) or []
        for descendant in descendants:
            if cmds.nodeType(descendant) == "hikFKJoint":
                set_look_style(descendant, lookStyle)

# Function to generate a square or circle depending on selection from the dropdown at the location of the selected bone
def generate_shape(shape):
    """Create a control curve at the world position of each selected joint.

    Selected nodes that are not of type ``joint`` are ignored.

    Args:
        shape: Index chosen in the "Generate Shape" dropdown. ``0`` builds a
            NURBS circle that is pivoted on the joint, rotated to the joint's
            world rotation, frozen, and renamed to ``CTRL_<joint name>``.
            ``1`` builds a linear four-sided closed curve around the joint
            position. Any other value creates nothing.
    """
    for obj in cmds.ls(selection=True):
        if cmds.nodeType(obj) != "joint":
            continue

        x, y, z = cmds.xform(obj, query=True, worldSpace=True, translation=True)

        if shape == 0:
            # Keep a handle on the new transform and pass it explicitly to
            # every following command instead of relying on whatever happens
            # to be selected at that moment.
            circle = cmds.circle(center=(x, y, z), normal=(0, 1, 0), radius=50)[0]
            # Move the pivot onto the joint so the rotation below happens
            # around the joint position rather than the world origin.
            cmds.xform(circle, pivots=(x, y, z))
            joint_rotation = cmds.xform(obj, query=True, worldSpace=True, rotation=True)
            cmds.xform(circle, rotation=joint_rotation)
            cmds.makeIdentity(circle, apply=True, translate=True, rotate=True, scale=True)
            # The selection entry can be a DAG path ("parent|joint") when the
            # short name is not unique; "|" is not legal in a node name, so
            # only the leaf name is used for the control.
            cmds.rename(circle, "CTRL_" + obj.split("|")[-1])
        elif shape == 1:
            # Visit the four points in order around the centre so the curve
            # forms a closed quadrilateral; the previous left/right/front/back
            # order crossed through the centre and drew a bow-tie.
            cmds.curve(
                degree=1,
                point=[
                    (x - 1, y, z),
                    (x, y, z + 1),
                    (x + 1, y, z),
                    (x, y, z - 1),
                    (x - 1, y, z),
                ],
            )
    
# Function to get the selection from the dropdown menu
def get_index_from_menu(menuName):
    """Return the zero-based index of the item selected in an option menu.

    Args:
        menuName: Name of the ``optionMenu`` control to query.

    Returns:
        The value reported by the ``select`` query minus one.
    """
    selected_item = cmds.optionMenu(menuName, q=True, select=True)
    return selected_item - 1

def set_local_axis(state):
    """Set local-axis display on the selection and hide it on child transforms.

    Each selected node gets ``displayLocalAxis`` set to ``state``. Every
    descendant whose node type is exactly ``transform`` gets
    ``displayLocalAxis`` set to 0, regardless of ``state``; descendants of
    any other type are left untouched.

    Args:
        state: Any truthy/falsy value (for example a checkbox state); it is
            normalised to 1 or 0 before being written.
    """
    state = int(bool(state))

    for obj in cmds.ls(selection=True):
        cmds.setAttr(obj + ".displayLocalAxis", state)

        # listRelatives returns None (not an empty list) when the node has no
        # descendants, so fall back to [] to keep childless nodes from raising.
        # Full paths keep the names unambiguous when short names repeat.
        descendants = cmds.listRelatives(
            obj, allDescendents=True, fullPath=True
        ) or []
        for descendant in descendants:
            if cmds.nodeType(descendant) == "transform":
                cmds.setAttr(descendant + ".displayLocalAxis", 0)

def set_rotate_order(rotate_order):
    """Set the rotate order on every selected joint and all joints below it.

    Selected nodes that are not of type ``joint`` are skipped, together with
    their hierarchy. An unknown ``rotate_order`` emits a Maya warning and
    changes nothing.

    Args:
        rotate_order: One of "xyz", "yzx", "zxy", "xzy", "yxz", "zyx".
    """
    # Maps each rotate-order name to the value written to ``rotateOrder``.
    rotate_order_values = {
        "xyz": 0,
        "yzx": 1,
        "zxy": 2,
        "xzy": 3,
        "yxz": 4,
        "zyx": 5,
    }

    if rotate_order not in rotate_order_values:
        cmds.warning(f"Invalid rotate order: {rotate_order}")
        return

    rotate_order_value = rotate_order_values[rotate_order]

    for obj in cmds.ls(selection=True):
        if cmds.nodeType(obj) != "joint":
            continue

        cmds.setAttr(obj + ".rotateOrder", rotate_order_value)

        print(f"Setting rotate order for {obj} to {rotate_order}")

        # listRelatives returns None (not an empty list) when the joint has
        # no descendants, so fall back to [] to keep leaf joints from raising.
        # Full paths keep the names unambiguous when short names repeat.
        descendants = cmds.listRelatives(
            obj, allDescendents=True, fullPath=True
        ) or []
        for descendant in descendants:
            if cmds.nodeType(descendant) == "joint":
                cmds.setAttr(descendant + ".rotateOrder", rotate_order_value)

def set_joint_radius(radius):
    """Set the display radius of every selected joint and all joints below it.

    Selected nodes that are not of type ``joint`` are skipped, together with
    their hierarchy.

    Args:
        radius: Value written to each joint's ``radius`` attribute (the UI
            passes the "Joint Radius" slider value).
    """
    for obj in cmds.ls(selection=True):
        if cmds.nodeType(obj) != "joint":
            continue

        cmds.setAttr(obj + ".radius", radius)

        # listRelatives returns None (not an empty list) when the joint has
        # no descendants, so fall back to [] to keep leaf joints from raising.
        # Full paths keep the names unambiguous when short names repeat.
        descendants = cmds.listRelatives(
            obj, allDescendents=True, fullPath=True
        ) or []
        for descendant in descendants:
            if cmds.nodeType(descendant) == "joint":
                cmds.setAttr(descendant + ".radius", radius)

# Reuse an existing QApplication when there is one (the host application is
# expected to own it); otherwise create our own so the window can be shown.
# NOTE(review): `ui` is constructed earlier in this file, before this check.
# If BaseUI creates widgets in its constructor, a run without a host
# QApplication would need the application created before that point - confirm.
app = None
if not QApplication.instance():
    app = QApplication([])

# Create and show the UI
ui.show()

# Start the Qt event loop only when this script created the application, and
# only after the window is shown: exec_() blocks until the application quits,
# so calling it first meant show() was never reached.
if app is not None:
    app.exec_()