import * as React from 'react'
import useMutationObserver from '@rooks/use-mutation-observer'

/**
 * useClassNameWatch hook
 *
 * Watches the `class` attribute of the element held by `ref` and reports
 * whether it carries the requested class names.
 *
 * The state is computed once for the first render, re-synchronised after
 * every commit in which `ref`, `classNames` or `matchAll` changed (this also
 * covers a ref that was still `null` during the first render), and updated
 * whenever a mutation of the `class` attribute is observed.
 *
 * @param ref ref to the element to watch, e.g. React.useRef(window.document.body)
 * @param classNames array of className strings to look for
 * @param matchAll if true, all items must be present in the classList;
 *   otherwise a single matching item is enough
 * @returns an object with `result` (whether the element matches) and
 *   `matchingClassNames` (the entries of `classNames` found on the element)
 */
export const useClassNameWatch = (
  ref: React.MutableRefObject<HTMLElement | null>,
  classNames: string[],
  matchAll: boolean = false
) => {
  type WatchState = {
    result: boolean
    matchingClassNames: string[]
  }

  // Derives the hook state from the element's current class list; a missing
  // element is treated as an element without any classes.
  const computeState = React.useCallback(
    (element: HTMLElement | null): WatchState => {
      const currentClassNames: string[] = element
        ? Array.from(element.classList)
        : []
      return {
        result: isClassNameMatch(currentClassNames, classNames, matchAll),
        matchingClassNames: intersection(currentClassNames, classNames)
      }
    },
    [classNames, matchAll]
  )

  // Lazy initializer: the initial value is only needed for the first render.
  const [state, setState] = React.useState<WatchState>(() =>
    computeState(ref.current)
  )

  // Recomputes the state for `element`, keeping the previous state object when
  // nothing changed so that React can bail out of the re-render. This is what
  // keeps the effect below from looping when callers pass a fresh `classNames`
  // array on every render.
  const sync = React.useCallback(
    (element: HTMLElement | null) => {
      const next = computeState(element)
      setState((previous) => {
        const unchanged =
          previous.result === next.result &&
          previous.matchingClassNames.length ===
            next.matchingClassNames.length &&
          previous.matchingClassNames.every(
            (className, index) => className === next.matchingClassNames[index]
          )
        return unchanged ? previous : next
      })
    },
    [computeState]
  )

  // The ref may be unset while rendering and only populated on commit, and the
  // inputs may change between renders; re-read the element once mounted.
  React.useEffect(() => {
    sync(ref.current)
  }, [ref, sync])

  const mutationCallback = React.useCallback(
    (mutations: MutationRecord[]) => {
      // The class list is read live from the target, so one update per batch
      // is enough; use the target of the last `class` mutation in the batch.
      let lastTarget: HTMLElement | null = null
      for (const mutation of mutations) {
        if (
          mutation.type === 'attributes' &&
          mutation.attributeName === 'class'
        ) {
          lastTarget = mutation.target as HTMLElement
        }
      }
      if (lastTarget !== null) {
        sync(lastTarget)
      }
    },
    [sync]
  )

  useMutationObserver(ref, mutationCallback, {
    attributes: true,
    attributeFilter: ['class']
  })

  return state
}

//-------------------------------------------------------------
// Helpers
//-------------------------------------------------------------
/**
 * Returns the entries of `classNames` that also occur in `targetClassNames`.
 * The order (and any duplicates) of `classNames` is preserved.
 */
function intersection(
  targetClassNames: string[],
  classNames: string[]
): string[] {
  const present = new Set(targetClassNames)
  const shared: string[] = []
  for (const candidate of classNames) {
    if (present.has(candidate)) {
      shared.push(candidate)
    }
  }
  return shared
}

/**
 * Tells whether the element's classes satisfy the requested class names.
 *
 * @param targetClassNames class names currently present on the element
 * @param classNames class names being looked for
 * @param matchAll when true, every entry of `classNames` must be present in
 *   `targetClassNames`; when false, one present entry is enough
 * @returns true on a match. With an empty `classNames` array this follows
 *   `Array.prototype.every` / `some`: true for `matchAll`, false otherwise.
 */
function isClassNameMatch(
  targetClassNames: string[],
  classNames: string[],
  matchAll: boolean
): boolean {
  // Iterate over the *requested* names: extra classes on the element must not
  // prevent a match, and a requested class that is absent must prevent one.
  const isPresent = (className: string) => targetClassNames.includes(className)
  return matchAll ? classNames.every(isPresent) : classNames.some(isPresent)
}
