import { REFRESH_TOKEN_BUFFER_MS } from '../../constants/config'
import { toSafeError } from '../../utils/error'
import { getExpiresAtMsFromJWT } from '../../utils/jwtHelper'
import { now, setLongTimeout } from '../../utils/time'

/**
 * Awaits `asyncFunctionThatMayThrow` and turns its outcome into a result
 * object instead of letting an exception escape.
 *
 * @param asyncFunctionThatMayThrow - The operation to run.
 * @returns `{ data, errorMessage: null }` when the operation succeeds, or
 *   `{ data: null, errorMessage }` carrying the message of the error
 *   (normalized through `toSafeError`) when it throws or rejects.
 */
async function callWithErrorHandling<TResult>(
	asyncFunctionThatMayThrow: () => Promise<TResult>
): Promise<
	| { data: TResult; errorMessage: null }
	| { data: null; errorMessage: string }
> {
	let data: TResult
	try {
		data = await asyncFunctionThatMayThrow()
	} catch (caught) {
		const { message } = toSafeError(caught)
		return { data: null, errorMessage: message }
	}
	return { data, errorMessage: null }
}

// Sentinel error messages thrown inside a refresh cycle so the refresh loop
// can tell why the cycle ended: a cancelled wait is not fatal (the loop
// re-checks its stop flag), while a failed refresh ends the loop.
/** Message of the error raised when a pending wait is cancelled. */
const TIMEOUT_CANCELLED_ERROR_MESSAGE = 'Timeout cancelled'
/** Message of the error raised after `refreshAccessToken` has failed. */
const REFRESH_FAILED_ERROR_MESSAGE = 'Refresh failed'

/**
 * Optional hooks that `AccessTokenRefreshManager` invokes while it keeps an
 * access token fresh. Any hook may be omitted.
 */
export interface AccessTokenRefreshCallbacks {
	/** Called right before the manager asks for a new access token. */
	onRefreshRequest?(): void
	/**
	 * Called with the newly obtained access token after a successful refresh
	 * (not called if a stop was requested while the refresh was in flight).
	 */
	onRefreshSuccess?(accessToken: string): void
	/**
	 * Called with the error message when fetching a new access token fails.
	 * The manager stops refreshing afterwards.
	 */
	onRefreshFailure?(errorMessage: string): void
	/**
	 * Called with the error message for other unexpected errors (e.g. an
	 * access token whose expiration time cannot be read). The manager stops
	 * refreshing afterwards.
	 */
	onError?(errorMessage: string): void
}

/**
 * Lifecycle state of an `AccessTokenRefreshManager`.
 *
 * String-valued so the state is readable when logged or inspected, and so
 * the enum has no numeric reverse mapping.
 */
enum AccessTokenRefreshManagerState {
	/** The refresh loop ran and has exited (stopped or hit a fatal error). */
	Stopped = 'stopped',
	/** The refresh loop is active. */
	Running = 'running',
	/** Constructed, but refreshing has never been started. */
	Ready = 'ready',
}

/**
 * Keeps an access token fresh by repeatedly waiting until shortly before the
 * current token expires (`REFRESH_TOKEN_BUFFER_MS` early) and then calling
 * `refreshAccessToken`.
 *
 * At most one refresh loop runs at a time. `stopRefreshingAccessToken` asks
 * the loop to stop. If `maintainFreshAccessToken` is called again before the
 * loop has observed that request, the existing loop keeps running with the
 * new callbacks and no second loop is started.
 */
export class AccessTokenRefreshManager {
	private state: AccessTokenRefreshManagerState
	private shouldStop: boolean = false
	private cancelTimeout: (() => void) | undefined
	private callbacks: AccessTokenRefreshCallbacks = {}

	/**
	 * @param refreshAccessToken - Fetches a new access token. A rejection, or a
	 *   falsy result, ends the refresh loop.
	 * @param accessToken - The current access token, if any. When absent, the
	 *   first refresh happens as soon as the loop starts.
	 */
	constructor(
		private refreshAccessToken: () => Promise<string>,
		private accessToken: string | null = null
	) {
		this.state = AccessTokenRefreshManagerState.Ready
	}

	/**
	 * Waits until `REFRESH_TOKEN_BUFFER_MS` before `expiresAt`, or not at all if
	 * that moment has already passed. While waiting, `this.cancelTimeout` aborts
	 * the wait, which then rejects with `TIMEOUT_CANCELLED_ERROR_MESSAGE`.
	 *
	 * @param expiresAt - Expiration timestamp, compared against `now()`.
	 */
	private async cancellablyWaitForExpiration(expiresAt: number): Promise<void> {
		const msUntilExpiresAt = expiresAt - now()
		const durationToWait = Math.max(
			msUntilExpiresAt - REFRESH_TOKEN_BUFFER_MS,
			0
		)
		try {
			await new Promise((resolve, reject) => {
				const cancelTimeout = setLongTimeout(resolve, durationToWait)
				this.cancelTimeout = () => {
					cancelTimeout()
					reject(new Error(TIMEOUT_CANCELLED_ERROR_MESSAGE))
				}
			})
		} finally {
			// Clear the canceller whether the wait elapsed or was cancelled, so a
			// later stop request never invokes a stale canceller.
			this.cancelTimeout = undefined
		}
	}

	/**
	 * Runs one refresh cycle. If there is a current token, it first waits until
	 * that token is about to expire, then it requests a new token.
	 *
	 * Returns early, without refreshing or updating the token, when a stop has
	 * been requested.
	 *
	 * @throws Error with `TIMEOUT_CANCELLED_ERROR_MESSAGE` if the wait was
	 *   cancelled.
	 * @throws Error with `REFRESH_FAILED_ERROR_MESSAGE` if `refreshAccessToken`
	 *   failed (after `onRefreshFailure` has been notified).
	 * @throws Error for other problems: unreadable token expiration, an empty
	 *   refreshed token, or a callback that throws.
	 */
	private async refreshAccessTokenBeforeItExpires(): Promise<void> {
		if (this.accessToken) {
			const expiresAt = getExpiresAtMsFromJWT(this.accessToken)
			if (!expiresAt) {
				throw new Error('Unable to get expiration time of access token')
			}
			await this.cancellablyWaitForExpiration(expiresAt)
		}

		if (this.shouldStop) {
			return
		}

		this.callbacks.onRefreshRequest?.()

		const { data: newAccessToken, errorMessage: refreshErrorMessage } =
			await callWithErrorHandling(this.refreshAccessToken)

		if (this.shouldStop) {
			// A stop was requested while the refresh was in flight: drop the
			// result without notifying any callback.
			return
		}

		// Compare against null instead of testing truthiness: an error whose
		// message is the empty string is still a refresh failure.
		if (refreshErrorMessage !== null) {
			this.callbacks.onRefreshFailure?.(refreshErrorMessage)
			throw new Error(REFRESH_FAILED_ERROR_MESSAGE)
		}

		// Validate before storing so a missing or empty result never replaces
		// the token we already hold.
		if (!newAccessToken) {
			throw new Error('Refreshing accessToken failed')
		}

		this.accessToken = newAccessToken
		this.callbacks.onRefreshSuccess?.(this.accessToken)
	}

	/**
	 * Runs refresh cycles until a stop is requested or a fatal error occurs.
	 * On exit it always leaves the manager `Stopped` with the stop flag
	 * cleared, even if a callback throws.
	 */
	private async continuallyRefresh(): Promise<void> {
		let hasEncounteredFatalError = false
		try {
			while (!this.shouldStop && !hasEncounteredFatalError) {
				try {
					await this.refreshAccessTokenBeforeItExpires()
				} catch (unknownError) {
					const error = toSafeError(unknownError)
					switch (error.message) {
						case TIMEOUT_CANCELLED_ERROR_MESSAGE:
							// Not fatal: defer to the shouldStop flag, which
							// maintainFreshAccessToken may have reset to false.
							break
						case REFRESH_FAILED_ERROR_MESSAGE:
							// onRefreshFailure has already been notified.
							hasEncounteredFatalError = true
							break
						default:
							hasEncounteredFatalError = true
							this.callbacks.onError?.(error.message)
					}
				}
			}
		} finally {
			// Reset here so that a throwing onError callback cannot leave the
			// manager stuck in the Running state, which would block restarts.
			this.state = AccessTokenRefreshManagerState.Stopped
			this.shouldStop = false
		}
	}

	/**
	 * Starts refreshing the access token in the background, or keeps an
	 * already-running loop going. Replaces any previously registered callbacks
	 * and withdraws any pending stop request.
	 *
	 * @param callbacks - Hooks notified about refresh progress and errors.
	 */
	public maintainFreshAccessToken(callbacks: AccessTokenRefreshCallbacks) {
		this.callbacks = callbacks
		this.shouldStop = false
		if (!this.isRunning()) {
			// Mark as running before launching, so the state never depends on
			// when the loop first yields.
			this.state = AccessTokenRefreshManagerState.Running
			// Deliberately not awaited: the loop runs in the background and
			// handles per-cycle errors itself. A rejection here would come from
			// its error handling (e.g. an onError callback that throws).
			void this.continuallyRefresh()
		}
	}

	/**
	 * Asks the refresh loop to stop. A pending wait is cancelled immediately.
	 * An in-flight refresh is allowed to settle, and its result is discarded
	 * unless refreshing is resumed first. The loop exits asynchronously.
	 */
	public stopRefreshingAccessToken() {
		this.shouldStop = true
		const cancelPendingWait = this.cancelTimeout
		this.cancelTimeout = undefined
		cancelPendingWait?.()
	}

	/** Whether the background refresh loop is currently active. */
	public isRunning() {
		return this.state === AccessTokenRefreshManagerState.Running
	}
}
