From 123b9fb5356ee2147c66040fbf1bde98a9cc5f65 Mon Sep 17 00:00:00 2001 From: Krishna Acondy Date: Mon, 12 Jul 2021 20:31:17 +0100 Subject: [PATCH] chore(refactor): split up and add tests for core functionality --- src/SASViyaApiClient.ts | 4 +- src/api/viya/executeScript.ts | 2 +- src/api/viya/pollJobState.ts | 56 +---- src/api/viya/saveLog.ts | 40 ++++ src/api/viya/spec/executeScript.spec.ts | 20 +- src/api/viya/spec/pollJobState.spec.ts | 266 ++++++++++++++++++++++++ src/api/viya/spec/saveLog.spec.ts | 72 +++++++ src/auth/getAccessToken.ts | 49 +++++ src/auth/getTokens.ts | 40 ++++ src/auth/refreshTokens.ts | 49 +++++ src/auth/spec/getTokens.spec.ts | 79 +++++++ src/auth/spec/mockResponses.ts | 22 ++ src/auth/spec/refreshTokens.spec.ts | 75 +++++++ src/auth/tokens.ts | 122 ----------- 14 files changed, 717 insertions(+), 179 deletions(-) create mode 100644 src/api/viya/saveLog.ts create mode 100644 src/api/viya/spec/pollJobState.spec.ts create mode 100644 src/api/viya/spec/saveLog.spec.ts create mode 100644 src/auth/getAccessToken.ts create mode 100644 src/auth/getTokens.ts create mode 100644 src/auth/refreshTokens.ts create mode 100644 src/auth/spec/getTokens.spec.ts create mode 100644 src/auth/spec/refreshTokens.spec.ts delete mode 100644 src/auth/tokens.ts diff --git a/src/SASViyaApiClient.ts b/src/SASViyaApiClient.ts index 78d92b5..ee439f2 100644 --- a/src/SASViyaApiClient.ts +++ b/src/SASViyaApiClient.ts @@ -19,9 +19,11 @@ import { isAuthorizeFormRequired } from './auth/isAuthorizeFormRequired' import { RequestClient } from './request/RequestClient' import { prefixMessage } from '@sasjs/utils/error' import { pollJobState } from './api/viya/pollJobState' -import { getAccessToken, getTokens, refreshTokens } from './auth/tokens' +import { getTokens } from './auth/getTokens' import { uploadTables } from './api/viya/uploadTables' import { executeScript } from './api/viya/executeScript' +import { getAccessToken } from './auth/getAccessToken' +import { refreshTokens } from './auth/refreshTokens' /** * A client for interfacing with the SAS Viya REST API. diff --git a/src/api/viya/executeScript.ts b/src/api/viya/executeScript.ts index 0abe466..c75a3e5 100644 --- a/src/api/viya/executeScript.ts +++ b/src/api/viya/executeScript.ts @@ -7,7 +7,7 @@ import { ComputeJobExecutionError, NotFoundError } from '../..' -import { getTokens } from '../../auth/tokens' +import { getTokens } from '../../auth/getTokens' import { RequestClient } from '../../request/RequestClient' import { SessionManager } from '../../SessionManager' import { isRelativePath, fetchLogByChunks } from '../../utils' diff --git a/src/api/viya/pollJobState.ts b/src/api/viya/pollJobState.ts index 63796ce..8d95c42 100644 --- a/src/api/viya/pollJobState.ts +++ b/src/api/viya/pollJobState.ts @@ -1,11 +1,10 @@ -import { AuthConfig } from '@sasjs/utils' +import { AuthConfig } from '@sasjs/utils/types' import { prefixMessage } from '@sasjs/utils/error' import { generateTimestamp } from '@sasjs/utils/time' -import { createFile } from '@sasjs/utils/file' import { Job, PollOptions } from '../..' -import { getTokens } from '../../auth/tokens' +import { getTokens } from '../../auth/getTokens' import { RequestClient } from '../../request/RequestClient' -import { fetchLogByChunks } from '../../utils' +import { saveLog } from './saveLog' export async function pollJobState( requestClient: RequestClient, @@ -48,7 +47,7 @@ export async function pollJobState( } const stateLink = postedJob.links.find((l: any) => l.rel === 'state') if (!stateLink) { - return Promise.reject(`Job state link was not found.`) + throw new Error(`Job state link was not found.`) } const { result: state } = await requestClient @@ -72,7 +71,7 @@ export async function pollJobState( return Promise.resolve(currentState) } - return new Promise(async (resolve, _) => { + return new Promise(async (resolve, reject) => { let printedState = '' const interval = setInterval(async () => { @@ -98,9 +97,12 @@ export async function pollJobState( .catch((err) => { errorCount++ if (pollCount >= maxPollCount || errorCount >= maxErrorCount) { - throw prefixMessage( - err, - 'Error while getting job state after interval. ' + clearInterval(interval) + reject( + prefixMessage( + err, + 'Error while getting job state after interval. ' + ) ) } logger.error( @@ -143,39 +145,3 @@ export async function pollJobState( }, pollInterval) }) } - -async function saveLog( - job: Job, - requestClient: RequestClient, - shouldSaveLog: boolean, - logFilePath: string, - accessToken?: string -) { - if (!shouldSaveLog) { - return - } - - if (!accessToken) { - throw new Error( - `Logs for job ${job.id} cannot be fetched without a valid access token.` - ) - } - - const logger = process.logger || console - const jobLogUrl = job.links.find((l) => l.rel === 'log') - - if (!jobLogUrl) { - throw new Error(`Log URL for job ${job.id} was not found.`) - } - - const logCount = job.logStatistics?.lineCount ?? 1000000 - const log = await fetchLogByChunks( - requestClient, - accessToken, - `${jobLogUrl.href}/content`, - logCount - ) - - logger.info(`Writing logs to ${logFilePath}`) - await createFile(logFilePath, log) -} diff --git a/src/api/viya/saveLog.ts b/src/api/viya/saveLog.ts new file mode 100644 index 0000000..9200930 --- /dev/null +++ b/src/api/viya/saveLog.ts @@ -0,0 +1,40 @@ +import { createFile } from '@sasjs/utils/file' +import { Job } from '../..' +import { RequestClient } from '../../request/RequestClient' +import { fetchLogByChunks } from '../../utils' + +export async function saveLog( + job: Job, + requestClient: RequestClient, + shouldSaveLog: boolean, + logFilePath: string, + accessToken?: string +) { + if (!shouldSaveLog) { + return + } + + if (!accessToken) { + throw new Error( + `Logs for job ${job.id} cannot be fetched without a valid access token.` + ) + } + + const logger = process.logger || console + const jobLogUrl = job.links.find((l) => l.rel === 'log') + + if (!jobLogUrl) { + throw new Error(`Log URL for job ${job.id} was not found.`) + } + + const logCount = job.logStatistics?.lineCount ?? 1000000 + const log = await fetchLogByChunks( + requestClient, + accessToken, + `${jobLogUrl.href}/content`, + logCount + ) + + logger.info(`Writing logs to ${logFilePath}`) + await createFile(logFilePath, log) +} diff --git a/src/api/viya/spec/executeScript.spec.ts b/src/api/viya/spec/executeScript.spec.ts index bca8fdb..314c70b 100644 --- a/src/api/viya/spec/executeScript.spec.ts +++ b/src/api/viya/spec/executeScript.spec.ts @@ -4,7 +4,7 @@ import { executeScript } from '../executeScript' import { mockSession, mockAuthConfig, mockJob } from './mockResponses' import * as pollJobStateModule from '../pollJobState' import * as uploadTablesModule from '../uploadTables' -import * as tokensModule from '../../../auth/tokens' +import * as getTokensModule from '../../../auth/getTokens' import * as formatDataModule from '../../../utils/formatDataForRequest' import * as fetchLogsModule from '../../../utils/fetchLogByChunks' import { PollOptions } from '../../../types' @@ -35,7 +35,7 @@ describe('executeScript', () => { 'test context' ) - expect(tokensModule.getTokens).not.toHaveBeenCalled() + expect(getTokensModule.getTokens).not.toHaveBeenCalled() }) it('should try to get fresh tokens if an authConfig is provided', async () => { @@ -49,7 +49,7 @@ describe('executeScript', () => { mockAuthConfig ) - expect(tokensModule.getTokens).toHaveBeenCalledWith( + expect(getTokensModule.getTokens).toHaveBeenCalledWith( requestClient, mockAuthConfig ) @@ -82,7 +82,7 @@ describe('executeScript', () => { 'test context' ).catch((e) => e) - expect(error.includes('Error while getting session.')).toBeTruthy() + expect(error).toContain('Error while getting session.') }) it('should fetch the PID when printPid is true', async () => { @@ -130,7 +130,7 @@ describe('executeScript', () => { true ).catch((e) => e) - expect(error.includes('Error while getting session variable.')).toBeTruthy() + expect(error).toContain('Error while getting session variable.') }) it('should use the file upload approach when data contains semicolons', async () => { @@ -300,7 +300,7 @@ describe('executeScript', () => { ).catch((e) => e) console.log(error) - expect(error.includes('Error while posting job')).toBeTruthy() + expect(error).toContain('Error while posting job') }) it('should immediately return the session when waitForResult is false', async () => { @@ -371,7 +371,7 @@ describe('executeScript', () => { true ).catch((e) => e) - expect(error.includes('Error while polling job status.')).toBeTruthy() + expect(error).toContain('Error while polling job status.') }) it('should fetch the log and append it to the error in case of a 5113 error code', async () => { @@ -626,7 +626,7 @@ describe('executeScript', () => { true ).catch((e) => e) - expect(error.includes('Error while clearing session.')).toBeTruthy() + expect(error).toContain('Error while clearing session.') }) }) @@ -634,7 +634,7 @@ const setupMocks = () => { jest.restoreAllMocks() jest.mock('../../../request/RequestClient') jest.mock('../../../SessionManager') - jest.mock('../../../auth/tokens') + jest.mock('../../../auth/getTokens') jest.mock('../pollJobState') jest.mock('../uploadTables') jest.mock('../../../utils/formatDataForRequest') @@ -650,7 +650,7 @@ const setupMocks = () => { .spyOn(requestClient, 'delete') .mockImplementation(() => Promise.resolve({ result: {}, etag: '' })) jest - .spyOn(tokensModule, 'getTokens') + .spyOn(getTokensModule, 'getTokens') .mockImplementation(() => Promise.resolve(mockAuthConfig)) jest .spyOn(pollJobStateModule, 'pollJobState') diff --git a/src/api/viya/spec/pollJobState.spec.ts b/src/api/viya/spec/pollJobState.spec.ts new file mode 100644 index 0000000..23ac560 --- /dev/null +++ b/src/api/viya/spec/pollJobState.spec.ts @@ -0,0 +1,266 @@ +import { RequestClient } from '../../../request/RequestClient' +import { mockAuthConfig, mockJob } from './mockResponses' +import { pollJobState } from '../pollJobState' +import * as getTokensModule from '../../../auth/getTokens' +import * as saveLogModule from '../saveLog' +import { PollOptions } from '../../../types' +import { Logger, LogLevel } from '@sasjs/utils' + +const requestClient = new (>RequestClient)() +const defaultPollOptions: PollOptions = { + maxPollCount: 100, + pollInterval: 500, + streamLog: false +} + +describe('pollJobState', () => { + beforeEach(() => { + ;(process as any).logger = new Logger(LogLevel.Off) + setupMocks() + }) + + it('should get valid tokens if the authConfig has been provided', async () => { + await pollJobState( + requestClient, + mockJob, + false, + 'test', + mockAuthConfig, + defaultPollOptions + ) + + expect(getTokensModule.getTokens).toHaveBeenCalledWith( + requestClient, + mockAuthConfig + ) + }) + + it('should not attempt to get tokens if the authConfig has not been provided', async () => { + await pollJobState( + requestClient, + mockJob, + false, + 'test', + undefined, + defaultPollOptions + ) + + expect(getTokensModule.getTokens).not.toHaveBeenCalled() + }) + + it('should throw an error if the job does not have a state link', async () => { + const error = await pollJobState( + requestClient, + { ...mockJob, links: mockJob.links.filter((l) => l.rel !== 'state') }, + false, + 'test', + undefined, + defaultPollOptions + ).catch((e) => e) + + expect((error as Error).message).toContain('Job state link was not found.') + }) + + it('should attempt to refresh tokens before each poll', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => + Promise.resolve({ result: 'running', etag: '' }) + ) + .mockImplementation(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + + await pollJobState( + requestClient, + mockJob, + false, + 'test', + mockAuthConfig, + defaultPollOptions + ) + + expect(getTokensModule.getTokens).toHaveBeenCalledTimes(3) + }) + + it('should attempt to fetch and save the log after each poll', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => + Promise.resolve({ result: 'running', etag: '' }) + ) + .mockImplementation(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + + await pollJobState( + requestClient, + mockJob, + false, + 'test', + mockAuthConfig, + defaultPollOptions + ) + + expect(saveLogModule.saveLog).toHaveBeenCalledTimes(2) + }) + + it('should return the current status when the max poll count is reached', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => + Promise.resolve({ result: 'running', etag: '' }) + ) + + const state = await pollJobState( + requestClient, + mockJob, + false, + 'test', + mockAuthConfig, + { + ...defaultPollOptions, + maxPollCount: 1 + } + ) + + expect(state).toEqual('running') + }) + + it('should continue polling until the job completes or errors', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => + Promise.resolve({ result: 'running', etag: '' }) + ) + .mockImplementation(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + + const state = await pollJobState( + requestClient, + mockJob, + false, + 'test', + undefined, + defaultPollOptions + ) + + expect(requestClient.get).toHaveBeenCalledTimes(4) + expect(state).toEqual('completed') + }) + + it('should print the state to the console when debug is on', async () => { + jest.spyOn((process as any).logger, 'info') + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => + Promise.resolve({ result: 'running', etag: '' }) + ) + .mockImplementation(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + + await pollJobState( + requestClient, + mockJob, + true, + 'test', + undefined, + defaultPollOptions + ) + + expect((process as any).logger.info).toHaveBeenCalledTimes(4) + expect((process as any).logger.info).toHaveBeenNthCalledWith( + 1, + 'Polling job status...' + ) + expect((process as any).logger.info).toHaveBeenNthCalledWith( + 2, + 'Current job state: running' + ) + expect((process as any).logger.info).toHaveBeenNthCalledWith( + 3, + 'Polling job status...' + ) + expect((process as any).logger.info).toHaveBeenNthCalledWith( + 4, + 'Current job state: completed' + ) + }) + + it('should continue polling when there is a single error in between', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementationOnce(() => + Promise.resolve({ result: 'pending', etag: '' }) + ) + .mockImplementationOnce(() => Promise.reject('Status Error')) + .mockImplementationOnce(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + + const state = await pollJobState( + requestClient, + mockJob, + false, + 'test', + undefined, + defaultPollOptions + ) + + expect(requestClient.get).toHaveBeenCalledTimes(3) + expect(state).toEqual('completed') + }) + + it('should throw an error when the error count exceeds the set value of 5', async () => { + jest + .spyOn(requestClient, 'get') + .mockImplementation(() => Promise.reject('Status Error')) + + const error = await pollJobState( + requestClient, + mockJob, + false, + 'test', + undefined, + defaultPollOptions + ).catch((e) => e) + + expect(error).toContain('Error while getting job state after interval.') + }) +}) + +const setupMocks = () => { + jest.restoreAllMocks() + jest.mock('../../../request/RequestClient') + jest.mock('../../../auth/getTokens') + jest.mock('../saveLog') + + jest + .spyOn(requestClient, 'get') + .mockImplementation(() => + Promise.resolve({ result: 'completed', etag: '' }) + ) + jest + .spyOn(getTokensModule, 'getTokens') + .mockImplementation(() => Promise.resolve(mockAuthConfig)) + jest + .spyOn(saveLogModule, 'saveLog') + .mockImplementation(() => Promise.resolve()) +} diff --git a/src/api/viya/spec/saveLog.spec.ts b/src/api/viya/spec/saveLog.spec.ts new file mode 100644 index 0000000..c4b8b9d --- /dev/null +++ b/src/api/viya/spec/saveLog.spec.ts @@ -0,0 +1,72 @@ +import { Logger, LogLevel } from '@sasjs/utils' +import * as fileModule from '@sasjs/utils/file' +import { RequestClient } from '../../../request/RequestClient' +import * as fetchLogsModule from '../../../utils/fetchLogByChunks' +import { saveLog } from '../saveLog' +import { mockJob } from './mockResponses' + +const requestClient = new (>RequestClient)() + +describe('saveLog', () => { + beforeEach(() => { + ;(process as any).logger = new Logger(LogLevel.Off) + setupMocks() + }) + + it('should return immediately if shouldSaveLog is false', async () => { + await saveLog(mockJob, requestClient, false, '/test', 't0k3n') + + expect(fetchLogsModule.fetchLogByChunks).not.toHaveBeenCalled() + expect(fileModule.createFile).not.toHaveBeenCalled() + }) + + it('should throw an error when a valid access token is not provided', async () => { + const error = await saveLog(mockJob, requestClient, true, '/test').catch( + (e) => e + ) + + expect(error.message).toContain( + `Logs for job ${mockJob.id} cannot be fetched without a valid access token.` + ) + }) + + it('should throw an error when the log URL is not available', async () => { + const error = await saveLog( + { ...mockJob, links: mockJob.links.filter((l) => l.rel !== 'log') }, + requestClient, + true, + '/test', + 't0k3n' + ).catch((e) => e) + + expect(error.message).toContain( + `Log URL for job ${mockJob.id} was not found.` + ) + }) + + it('should fetch and save logs to the given path', async () => { + await saveLog(mockJob, requestClient, true, '/test', 't0k3n') + + expect(fetchLogsModule.fetchLogByChunks).toHaveBeenCalledWith( + requestClient, + 't0k3n', + '/log/content', + 100 + ) + expect(fileModule.createFile).toHaveBeenCalledWith('/test', 'Test Log') + }) +}) + +const setupMocks = () => { + jest.restoreAllMocks() + jest.mock('../../../request/RequestClient') + jest.mock('../../../utils/fetchLogByChunks') + jest.mock('@sasjs/utils') + + jest + .spyOn(fetchLogsModule, 'fetchLogByChunks') + .mockImplementation(() => Promise.resolve('Test Log')) + jest + .spyOn(fileModule, 'createFile') + .mockImplementation(() => Promise.resolve()) +} diff --git a/src/auth/getAccessToken.ts b/src/auth/getAccessToken.ts new file mode 100644 index 0000000..0b11340 --- /dev/null +++ b/src/auth/getAccessToken.ts @@ -0,0 +1,49 @@ +import { SasAuthResponse } from '@sasjs/utils' +import * as NodeFormData from 'form-data' +import { RequestClient } from '../request/RequestClient' + +/** + * Exchanges the auth code for an access token for the given client. + * @param requestClient - the pre-configured HTTP request client + * @param clientId - the client ID to authenticate with. + * @param clientSecret - the client secret to authenticate with. + * @param authCode - the auth code received from the server. + */ +export async function getAccessToken( + requestClient: RequestClient, + clientId: string, + clientSecret: string, + authCode: string +): Promise { + const url = '/SASLogon/oauth/token' + let token + if (typeof Buffer === 'undefined') { + token = btoa(clientId + ':' + clientSecret) + } else { + token = Buffer.from(clientId + ':' + clientSecret).toString('base64') + } + const headers = { + Authorization: 'Basic ' + token + } + + let formData + if (typeof FormData === 'undefined') { + formData = new NodeFormData() + } else { + formData = new FormData() + } + formData.append('grant_type', 'authorization_code') + formData.append('code', authCode) + + const authResponse = await requestClient + .post( + url, + formData, + undefined, + 'multipart/form-data; boundary=' + (formData as any)._boundary, + headers + ) + .then((res) => res.result as SasAuthResponse) + + return authResponse +} diff --git a/src/auth/getTokens.ts b/src/auth/getTokens.ts new file mode 100644 index 0000000..031c6a3 --- /dev/null +++ b/src/auth/getTokens.ts @@ -0,0 +1,40 @@ +import { + AuthConfig, + isAccessTokenExpiring, + isRefreshTokenExpiring, + hasTokenExpired +} from '@sasjs/utils' +import { RequestClient } from '../request/RequestClient' +import { refreshTokens } from './refreshTokens' + +/** + * Returns the auth configuration, refreshing the tokens if necessary. + * @param requestClient - the pre-configured HTTP request client + * @param authConfig - an object containing a client ID, secret, access token and refresh token + */ +export async function getTokens( + requestClient: RequestClient, + authConfig: AuthConfig +): Promise { + const logger = process.logger || console + let { access_token, refresh_token, client, secret } = authConfig + if ( + isAccessTokenExpiring(access_token) || + isRefreshTokenExpiring(refresh_token) + ) { + if (hasTokenExpired(refresh_token)) { + const error = + 'Unable to obtain new access token. Your refresh token has expired.' + logger.error(error) + throw new Error(error) + } + logger.info('Refreshing access and refresh tokens.') + ;({ access_token, refresh_token } = await refreshTokens( + requestClient, + client, + secret, + refresh_token + )) + } + return { access_token, refresh_token, client, secret } +} diff --git a/src/auth/refreshTokens.ts b/src/auth/refreshTokens.ts new file mode 100644 index 0000000..5871d63 --- /dev/null +++ b/src/auth/refreshTokens.ts @@ -0,0 +1,49 @@ +import { SasAuthResponse } from '@sasjs/utils/types' +import { prefixMessage } from '@sasjs/utils/error' +import * as NodeFormData from 'form-data' +import { RequestClient } from '../request/RequestClient' + +/** + * Exchanges the refresh token for an access token for the given client. + * @param requestClient - the pre-configured HTTP request client + * @param clientId - the client ID to authenticate with. + * @param clientSecret - the client secret to authenticate with. + * @param authCode - the refresh token received from the server. + */ +export async function refreshTokens( + requestClient: RequestClient, + clientId: string, + clientSecret: string, + refreshToken: string +) { + const url = '/SASLogon/oauth/token' + let token + token = + typeof Buffer === 'undefined' + ? btoa(clientId + ':' + clientSecret) + : Buffer.from(clientId + ':' + clientSecret).toString('base64') + + const headers = { + Authorization: 'Basic ' + token + } + + const formData = + typeof FormData === 'undefined' ? new NodeFormData() : new FormData() + formData.append('grant_type', 'refresh_token') + formData.append('refresh_token', refreshToken) + + const authResponse = await requestClient + .post( + url, + formData, + undefined, + 'multipart/form-data; boundary=' + (formData as any)._boundary, + headers + ) + .then((res) => res.result) + .catch((err) => { + throw prefixMessage(err, 'Error while refreshing tokens') + }) + + return authResponse +} diff --git a/src/auth/spec/getTokens.spec.ts b/src/auth/spec/getTokens.spec.ts new file mode 100644 index 0000000..de4397c --- /dev/null +++ b/src/auth/spec/getTokens.spec.ts @@ -0,0 +1,79 @@ +import { AuthConfig } from '@sasjs/utils' +import * as refreshTokensModule from '../refreshTokens' +import { generateToken, mockAuthResponse } from './mockResponses' +import { getTokens } from '../getTokens' +import { RequestClient } from '../../request/RequestClient' + +const requestClient = new (>RequestClient)() + +describe('getTokens', () => { + it('should attempt to refresh tokens if the access token is expiring', async () => { + setupMocks() + const access_token = generateToken(30) + const refresh_token = generateToken(86400000) + const authConfig: AuthConfig = { + access_token, + refresh_token, + client: 'cl13nt', + secret: 's3cr3t' + } + + await getTokens(requestClient, authConfig) + + expect(refreshTokensModule.refreshTokens).toHaveBeenCalledWith( + requestClient, + authConfig.client, + authConfig.secret, + authConfig.refresh_token + ) + }) + + it('should attempt to refresh tokens if the refresh token is expiring', async () => { + setupMocks() + const access_token = generateToken(86400000) + const refresh_token = generateToken(30) + const authConfig: AuthConfig = { + access_token, + refresh_token, + client: 'cl13nt', + secret: 's3cr3t' + } + + await getTokens(requestClient, authConfig) + + expect(refreshTokensModule.refreshTokens).toHaveBeenCalledWith( + requestClient, + authConfig.client, + authConfig.secret, + authConfig.refresh_token + ) + }) + + it('should throw an error if the refresh token has already expired', async () => { + setupMocks() + const access_token = generateToken(86400000) + const refresh_token = generateToken(-36000) + const authConfig: AuthConfig = { + access_token, + refresh_token, + client: 'cl13nt', + secret: 's3cr3t' + } + const expectedError = + 'Unable to obtain new access token. Your refresh token has expired.' + + const error = await getTokens(requestClient, authConfig).catch((e) => e) + + expect(error.message).toEqual(expectedError) + }) +}) + +const setupMocks = () => { + jest.restoreAllMocks() + jest.mock('../../request/RequestClient') + jest.mock('../refreshTokens') + + jest + .spyOn(refreshTokensModule, 'refreshTokens') + .mockImplementation(() => Promise.resolve(mockAuthResponse)) +} diff --git a/src/auth/spec/mockResponses.ts b/src/auth/spec/mockResponses.ts index 4ffcfb2..e15391a 100644 --- a/src/auth/spec/mockResponses.ts +++ b/src/auth/spec/mockResponses.ts @@ -1,2 +1,24 @@ +import { SasAuthResponse } from '@sasjs/utils/types' + export const mockLoginAuthoriseRequiredResponse = `
` export const mockLoginSuccessResponse = `You have signed in` + +export const mockAuthResponse: SasAuthResponse = { + access_token: 'acc355', + refresh_token: 'r3fr35h', + id_token: 'id', + token_type: 'bearer', + expires_in: new Date().valueOf(), + scope: 'default', + jti: 'test' +} + +export const generateToken = (timeToLiveSeconds: number): string => { + const exp = + new Date(new Date().getTime() + timeToLiveSeconds * 1000).getTime() / 1000 + const header = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9' + const payload = Buffer.from(JSON.stringify({ exp })).toString('base64') + const signature = '4-iaDojEVl0pJQMjrbM1EzUIfAZgsbK_kgnVyVxFSVo' + const token = `${header}.${payload}.${signature}` + return token +} diff --git a/src/auth/spec/refreshTokens.spec.ts b/src/auth/spec/refreshTokens.spec.ts new file mode 100644 index 0000000..8f88f1d --- /dev/null +++ b/src/auth/spec/refreshTokens.spec.ts @@ -0,0 +1,75 @@ +import { AuthConfig } from '@sasjs/utils' +import * as NodeFormData from 'form-data' +import { generateToken, mockAuthResponse } from './mockResponses' +import { RequestClient } from '../../request/RequestClient' +import { refreshTokens } from '../refreshTokens' + +const requestClient = new (>RequestClient)() + +describe('refreshTokens', () => { + it('should attempt to refresh tokens', async () => { + setupMocks() + const access_token = generateToken(30) + const refresh_token = generateToken(30) + const authConfig: AuthConfig = { + access_token, + refresh_token, + client: 'cl13nt', + secret: 's3cr3t' + } + jest + .spyOn(requestClient, 'post') + .mockImplementation(() => + Promise.resolve({ result: mockAuthResponse, etag: '' }) + ) + const token = Buffer.from( + authConfig.client + ':' + authConfig.secret + ).toString('base64') + + await refreshTokens( + requestClient, + authConfig.client, + authConfig.secret, + authConfig.refresh_token + ) + + expect(requestClient.post).toHaveBeenCalledWith( + '/SASLogon/oauth/token', + expect.any(NodeFormData), + undefined, + expect.stringContaining('multipart/form-data; boundary='), + { + Authorization: 'Basic ' + token + } + ) + }) + + it('should handle errors while refreshing tokens', async () => { + setupMocks() + const access_token = generateToken(30) + const refresh_token = generateToken(30) + const authConfig: AuthConfig = { + access_token, + refresh_token, + client: 'cl13nt', + secret: 's3cr3t' + } + jest + .spyOn(requestClient, 'post') + .mockImplementation(() => Promise.reject('Token Error')) + + const error = await refreshTokens( + requestClient, + authConfig.client, + authConfig.secret, + authConfig.refresh_token + ).catch((e) => e) + + expect(error).toContain('Error while refreshing tokens') + }) +}) + +const setupMocks = () => { + jest.restoreAllMocks() + jest.mock('../../request/RequestClient') +} diff --git a/src/auth/tokens.ts b/src/auth/tokens.ts deleted file mode 100644 index bc940e2..0000000 --- a/src/auth/tokens.ts +++ /dev/null @@ -1,122 +0,0 @@ -import { - AuthConfig, - isAccessTokenExpiring, - isRefreshTokenExpiring, - SasAuthResponse -} from '@sasjs/utils' -import * as NodeFormData from 'form-data' -import { RequestClient } from '../request/RequestClient' - -/** - * Exchanges the auth code for an access token for the given client. - * @param requestClient - the pre-configured HTTP request client - * @param clientId - the client ID to authenticate with. - * @param clientSecret - the client secret to authenticate with. - * @param authCode - the auth code received from the server. - */ -export async function getAccessToken( - requestClient: RequestClient, - clientId: string, - clientSecret: string, - authCode: string -): Promise { - const url = '/SASLogon/oauth/token' - let token - if (typeof Buffer === 'undefined') { - token = btoa(clientId + ':' + clientSecret) - } else { - token = Buffer.from(clientId + ':' + clientSecret).toString('base64') - } - const headers = { - Authorization: 'Basic ' + token - } - - let formData - if (typeof FormData === 'undefined') { - formData = new NodeFormData() - } else { - formData = new FormData() - } - formData.append('grant_type', 'authorization_code') - formData.append('code', authCode) - - const authResponse = await requestClient - .post( - url, - formData, - undefined, - 'multipart/form-data; boundary=' + (formData as any)._boundary, - headers - ) - .then((res) => res.result as SasAuthResponse) - - return authResponse -} - -/** - * Returns the auth configuration, refreshing the tokens if necessary. - * @param requestClient - the pre-configured HTTP request client - * @param authConfig - an object containing a client ID, secret, access token and refresh token - */ -export async function getTokens( - requestClient: RequestClient, - authConfig: AuthConfig -): Promise { - const logger = process.logger || console - let { access_token, refresh_token, client, secret } = authConfig - if ( - isAccessTokenExpiring(access_token) || - isRefreshTokenExpiring(refresh_token) - ) { - logger.info('Refreshing access and refresh tokens.') - ;({ access_token, refresh_token } = await refreshTokens( - requestClient, - client, - secret, - refresh_token - )) - } - return { access_token, refresh_token, client, secret } -} - -/** - * Exchanges the refresh token for an access token for the given client. - * @param requestClient - the pre-configured HTTP request client - * @param clientId - the client ID to authenticate with. - * @param clientSecret - the client secret to authenticate with. - * @param authCode - the refresh token received from the server. - */ -export async function refreshTokens( - requestClient: RequestClient, - clientId: string, - clientSecret: string, - refreshToken: string -) { - const url = '/SASLogon/oauth/token' - let token - if (typeof Buffer === 'undefined') { - token = btoa(clientId + ':' + clientSecret) - } else { - token = Buffer.from(clientId + ':' + clientSecret).toString('base64') - } - const headers = { - Authorization: 'Basic ' + token - } - - const formData = - typeof FormData === 'undefined' ? new NodeFormData() : new FormData() - formData.append('grant_type', 'refresh_token') - formData.append('refresh_token', refreshToken) - - const authResponse = await requestClient - .post( - url, - formData, - undefined, - 'multipart/form-data; boundary=' + (formData as any)._boundary, - headers - ) - .then((res) => res.result) - - return authResponse -}